diffusers 0.32.2__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in their public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +121 -86
  13. diffusers/loaders/lora_conversion_utils.py +504 -44
  14. diffusers/loaders/lora_pipeline.py +1769 -181
  15. diffusers/loaders/peft.py +167 -57
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +646 -72
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +20 -7
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +9 -1
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +2 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.2.dist-info/RECORD +0 -550
  387. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/kolors/text_encoder.py
@@ -104,13 +104,6 @@ class RMSNorm(torch.nn.Module):
         return (self.weight * hidden_states).to(input_dtype)
 
 
-def _config_to_kwargs(args):
-    common_kwargs = {
-        "dtype": args.torch_dtype,
-    }
-    return common_kwargs
-
-
 class CoreAttention(torch.nn.Module):
     def __init__(self, config: ChatGLMConfig, layer_number):
         super(CoreAttention, self).__init__()
@@ -314,7 +307,6 @@ class SelfAttention(torch.nn.Module):
             self.qkv_hidden_size,
             bias=config.add_bias_linear or config.add_qkv_bias,
             device=device,
-            **_config_to_kwargs(config),
         )
 
         self.core_attention = CoreAttention(config, self.layer_number)
@@ -325,7 +317,6 @@ class SelfAttention(torch.nn.Module):
             config.hidden_size,
             bias=config.add_bias_linear,
             device=device,
-            **_config_to_kwargs(config),
        )
 
     def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
@@ -449,7 +440,6 @@ class MLP(torch.nn.Module):
             config.ffn_hidden_size * 2,
             bias=self.add_bias,
             device=device,
-            **_config_to_kwargs(config),
         )
 
         def swiglu(x):
@@ -459,9 +449,7 @@ class MLP(torch.nn.Module):
         self.activation_func = swiglu
 
         # Project back to h.
-        self.dense_4h_to_h = nn.Linear(
-            config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device, **_config_to_kwargs(config)
-        )
+        self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device)
 
     def forward(self, hidden_states):
         # [s, b, 4hp]
@@ -488,18 +476,14 @@ class GLMBlock(torch.nn.Module):
 
         LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
         # Layernorm on the input data.
-        self.input_layernorm = LayerNormFunc(
-            config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
-        )
+        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)
 
         # Self attention.
         self.self_attention = SelfAttention(config, layer_number, device=device)
         self.hidden_dropout = config.hidden_dropout
 
         # Layernorm on the attention output
-        self.post_attention_layernorm = LayerNormFunc(
-            config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
-        )
+        self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)
 
         # MLP
         self.mlp = MLP(config, device=device)
@@ -569,9 +553,7 @@ class GLMTransformer(torch.nn.Module):
         if self.post_layer_norm:
             LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
             # Final layer norm before output.
-            self.final_layernorm = LayerNormFunc(
-                config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
-            )
+            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)
 
         self.gradient_checkpointing = False
 
@@ -605,7 +587,7 @@ class GLMTransformer(torch.nn.Module):
 
             layer = self._get_layer(index)
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                layer_ret = torch.utils.checkpoint.checkpoint(
+                layer_ret = self._gradient_checkpointing_func(
                     layer, hidden_states, attention_mask, rotary_pos_emb, kv_caches[index], use_cache
                 )
             else:
@@ -666,10 +648,6 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
         return position_ids
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, GLMTransformer):
-            module.gradient_checkpointing = value
-
 
 def default_init(cls, *args, **kwargs):
     return cls(*args, **kwargs)
@@ -683,9 +661,7 @@ class Embedding(torch.nn.Module):
 
         self.hidden_size = config.hidden_size
         # Word embeddings (parallel).
-        self.word_embeddings = nn.Embedding(
-            config.padded_vocab_size, self.hidden_size, dtype=config.torch_dtype, device=device
-        )
+        self.word_embeddings = nn.Embedding(config.padded_vocab_size, self.hidden_size, device=device)
         self.fp32_residual_connection = config.fp32_residual_connection
 
     def forward(self, input_ids):
@@ -788,16 +764,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
         )
 
-        self.rotary_pos_emb = RotaryEmbedding(
-            rotary_dim // 2, original_impl=config.original_rope, device=device, dtype=config.torch_dtype
-        )
+        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device)
         self.encoder = init_method(GLMTransformer, config, **init_kwargs)
         self.output_layer = init_method(
             nn.Linear,
             config.hidden_size,
             config.padded_vocab_size,
             bias=False,
-            dtype=config.torch_dtype,
             **init_kwargs,
         )
         self.pre_seq_len = config.pre_seq_len
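The hunks above track two library-wide refactors in 0.33.0: per-layer `dtype=config.torch_dtype` arguments are dropped so modules are built in the default dtype and cast once after loading, and gradient checkpointing is routed through the shared `_gradient_checkpointing_func` hook instead of per-class `_set_gradient_checkpointing` toggles. A minimal sketch of the dtype half, with `Cfg` as a hypothetical stand-in for `ChatGLMConfig`:

```python
import torch
import torch.nn as nn


class Cfg:  # hypothetical stand-in for ChatGLMConfig
    hidden_size = 32
    torch_dtype = torch.float16


# Before: every layer was built directly in config.torch_dtype.
eager = nn.Linear(Cfg.hidden_size, Cfg.hidden_size, dtype=Cfg.torch_dtype)

# After: layers are built in the default dtype and the loader casts the whole
# module once, e.g. via from_pretrained(..., torch_dtype=torch.float16).
lazy = nn.Linear(Cfg.hidden_size, Cfg.hidden_size).to(Cfg.torch_dtype)

assert eager.weight.dtype == lazy.weight.dtype == torch.float16
```

Building in the default dtype leaves `from_pretrained(..., torch_dtype=...)` and low-memory init paths in control of when weights are materialized and cast.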
diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
@@ -30,6 +30,7 @@ from ...schedulers import LCMScheduler
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -40,6 +41,13 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -226,7 +234,7 @@ class LatentConsistencyModelImg2ImgPipeline(
                 " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
             )
 
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
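The `getattr(self, "vae", None)` fallback recurs across many pipelines in this release: it lets a pipeline be constructed with `vae=None` (for example when a component is swapped in later) instead of crashing in `__init__`. The fallback of 8 matches what a standard 4-level Stable Diffusion VAE yields; a quick sketch of the arithmetic, with `block_out_channels` assumed from a typical SD VAE config:

```python
# A VAE with n resolution levels halves the spatial size n - 1 times,
# so the pixel-to-latent scale factor is 2 ** (n - 1).
block_out_channels = [128, 256, 512, 512]  # typical SD VAE config (assumed)
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
assert vae_scale_factor == 8  # the fallback used when vae is None
```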
@@ -952,6 +960,9 @@ class LatentConsistencyModelImg2ImgPipeline(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         denoised = denoised.to(prompt_embeds.dtype)
         if not output_type == "latent":
             image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
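The guarded `torch_xla` import plus a per-step `xm.mark_step()` is the signature pattern of this release's TPU support: `mark_step()` cuts XLA's lazily traced graph at the end of each denoising iteration, so each step is compiled and dispatched on its own rather than the whole loop being traced as one graph. A stripped-down, self-contained sketch of the pattern; the toy scheduler and model below are placeholders, not pipeline API:

```python
from types import SimpleNamespace

import torch
from diffusers.utils import is_torch_xla_available

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


def denoise(scheduler, unet, latents):
    # On TPU, mark_step() flushes the pending XLA graph once per iteration.
    for t in scheduler.timesteps:
        noise_pred = unet(latents, t)
        latents = scheduler.step(noise_pred, t, latents).prev_sample
        if XLA_AVAILABLE:
            xm.mark_step()
    return latents


# Toy stand-ins so the sketch runs anywhere (no real model or scheduler).
toy_scheduler = SimpleNamespace(
    timesteps=range(3),
    step=lambda pred, t, x: SimpleNamespace(prev_sample=x - 0.1 * pred),
)
denoise(toy_scheduler, lambda x, t: torch.zeros_like(x), torch.randn(1, 4))
```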
diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
@@ -29,6 +29,7 @@ from ...schedulers import LCMScheduler
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -39,8 +40,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -209,7 +218,7 @@ class LatentConsistencyModelPipeline(
             feature_extractor=feature_extractor,
             image_encoder=image_encoder,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.register_to_config(requires_safety_checker=requires_safety_checker)
 
@@ -881,6 +890,9 @@ class LatentConsistencyModelPipeline(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         denoised = denoised.to(prompt_embeds.dtype)
         if not output_type == "latent":
             image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -25,10 +25,19 @@ from transformers.utils import logging
 
 from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import is_torch_xla_available
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 class LDMTextToImagePipeline(DiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using latent diffusion.
@@ -202,6 +211,9 @@ class LDMTextToImagePipeline(DiffusionPipeline):
             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # scale and decode the image latents with vae
         latents = 1 / self.vqvae.config.scaling_factor * latents
         image = self.vqvae.decode(latents).sample
@@ -532,10 +544,6 @@ class LDMBertPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (LDMBertEncoder,)):
-            module.gradient_checkpointing = value
-
     @property
     def dummy_inputs(self):
         pad_token = self.config.pad_token_id
@@ -676,15 +684,8 @@ class LDMBertEncoder(LDMBertPreTrainedModel):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(encoder_layer),
+                layer_outputs = self._gradient_checkpointing_func(
+                    encoder_layer,
                     hidden_states,
                     attention_mask,
                     (head_mask[idx] if head_mask is not None else None),
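Here, as in the ChatGLM encoder above, the ad-hoc `create_custom_forward` wrapper around `torch.utils.checkpoint.checkpoint` gives way to a single `self._gradient_checkpointing_func` call site, and the old `_set_gradient_checkpointing` toggle is removed. A self-contained sketch of the shape of this hook; that it wraps `torch.utils.checkpoint.checkpoint` by default is an assumption shown for illustration:

```python
from functools import partial

import torch
import torch.utils.checkpoint


class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)
        self.gradient_checkpointing = False
        # Central hook; diffusers wires this up when gradient checkpointing
        # is enabled on the model (assumed default shown here).
        self._gradient_checkpointing_func = partial(
            torch.utils.checkpoint.checkpoint, use_reentrant=False
        )

    def forward(self, x):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # One shared call site replaces the per-module wrapper closure.
            return self._gradient_checkpointing_func(self.linear, x)
        return self.linear(x)


block = TinyBlock()
block.gradient_checkpointing = True
block(torch.randn(2, 8, requires_grad=True)).sum().backward()
```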
diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
@@ -15,11 +15,19 @@ from ...schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import PIL_INTERPOLATION
+from ...utils import PIL_INTERPOLATION, is_torch_xla_available
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
 def preprocess(image):
     w, h = image.size
     w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
@@ -174,6 +182,9 @@ class LDMSuperResolutionPipeline(DiffusionPipeline):
             # compute the previous noisy sample x_t -> x_t-1
             latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
 
+            if XLA_AVAILABLE:
+                xm.mark_step()
+
         # decode the image latents with the VQVAE
         image = self.vqvae.decode(latents).sample
         image = torch.clamp(image, -1.0, 1.0)
diffusers/pipelines/latte/pipeline_latte.py
@@ -30,8 +30,10 @@ from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     BACKENDS_MAPPING,
     BaseOutput,
+    deprecate,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -39,8 +41,16 @@ from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ...video_processor import VideoProcessor
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 if is_bs4_available():
     from bs4 import BeautifulSoup
 
@@ -180,7 +190,7 @@ class LattePipeline(DiffusionPipeline):
             tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
         )
 
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
 
     # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
@@ -592,6 +602,10 @@
     def num_timesteps(self):
         return self._num_timesteps
 
+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
     @property
     def interrupt(self):
         return self._interrupt
@@ -623,7 +637,7 @@
         clean_caption: bool = True,
         mask_feature: bool = True,
         enable_temporal_attentions: bool = True,
-        decode_chunk_size: Optional[int] = None,
+        decode_chunk_size: int = 14,
     ) -> Union[LattePipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -719,6 +733,7 @@
             negative_prompt_embeds,
         )
         self._guidance_scale = guidance_scale
+        self._current_timestep = None
         self._interrupt = False
 
         # 2. Default height and width to transformer
@@ -780,6 +795,7 @@
                 if self.interrupt:
                     continue
 
+                self._current_timestep = t
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
@@ -788,10 +804,11 @@
                 # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                 # This would be a good case for the `match` statement (Python 3.10+)
                 is_mps = latent_model_input.device.type == "mps"
+                is_npu = latent_model_input.device.type == "npu"
                 if isinstance(current_timestep, float):
-                    dtype = torch.float32 if is_mps else torch.float64
+                    dtype = torch.float32 if (is_mps or is_npu) else torch.float64
                 else:
-                    dtype = torch.int32 if is_mps else torch.int64
+                    dtype = torch.int32 if (is_mps or is_npu) else torch.int64
                 current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
             elif len(current_timestep.shape) == 0:
                 current_timestep = current_timestep[None].to(latent_model_input.device)
@@ -800,7 +817,7 @@
 
             # predict noise model_output
             noise_pred = self.transformer(
-                latent_model_input,
+                hidden_states=latent_model_input,
                 encoder_hidden_states=prompt_embeds,
                 timestep=current_timestep,
                 enable_temporal_attentions=enable_temporal_attentions,
@@ -836,8 +853,20 @@
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
-        if not output_type == "latents":
-            video = self.decode_latents(latents, video_length, decode_chunk_size=14)
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
+        self._current_timestep = None
+
+        if output_type == "latents":
+            deprecation_message = (
+                "Passing `output_type='latents'` is deprecated. Please pass `output_type='latent'` instead."
+            )
+            deprecate("output_type_latents", "1.0.0", deprecation_message, standard_warn=False)
+            output_type = "latent"
+
+        if not output_type == "latent":
+            video = self.decode_latents(latents, video_length, decode_chunk_size=decode_chunk_size)
             video = self.video_processor.postprocess_video(video=video, output_type=output_type)
         else:
             video = latents
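Taken together, the Latte hunks thread a `_current_timestep` property through the loop, add NPU-aware timestep dtypes, call `xm.mark_step()` per step, honor the caller's `decode_chunk_size` (previously hard-coded to 14), and deprecate the nonstandard `output_type="latents"` spelling in favor of `"latent"`. A hedged usage sketch; the checkpoint id is an assumption, so substitute the Latte weights you actually use:

```python
import torch
from diffusers import LattePipeline

# Checkpoint id assumed for illustration.
pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16).to("cuda")

# output_type="latents" now warns via deprecate() and is rewritten internally;
# prefer "latent", and pass decode_chunk_size instead of relying on the old
# hard-coded value of 14.
video = pipe("a small bird on a cherry branch", decode_chunk_size=8).frames
```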
diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
@@ -19,6 +19,7 @@ from ...schedulers import DDIMScheduler, DPMSolverMultistepScheduler
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -29,26 +30,32 @@ from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import LEditsPPDiffusionPipelineOutput, LEditsPPInversionPipelineOutput
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
-        >>> import PIL
-        >>> import requests
         >>> import torch
-        >>> from io import BytesIO
 
         >>> from diffusers import LEditsPPPipelineStableDiffusion
         >>> from diffusers.utils import load_image
 
         >>> pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
-        ...     "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+        ...     "runwayml/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
         ... )
+        >>> pipe.enable_vae_tiling()
         >>> pipe = pipe.to("cuda")
 
         >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/cherry_blossom.png"
-        >>> image = load_image(img_url).convert("RGB")
+        >>> image = load_image(img_url).resize((512, 512))
 
         >>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.1)
 
@@ -152,7 +159,7 @@ class LeditsGaussianSmoothing:
 
         # The gaussian kernel is the product of the gaussian function of each dimension.
         kernel = 1
-        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
+        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij")
         for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
             mean = (size - 1) / 2
             kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
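Passing `indexing="ij"` simply pins `torch.meshgrid` to its historical matrix-indexing default, which recent PyTorch releases warn about when left implicit; the produced kernel is unchanged. A quick equivalence check:

```python
import torch

a, b = torch.arange(3), torch.arange(4)
implicit = torch.meshgrid(a, b)                 # warns on recent PyTorch
explicit = torch.meshgrid(a, b, indexing="ij")  # same tensors, no warning
assert all(torch.equal(x, y) for x, y in zip(implicit, explicit))
```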
@@ -318,7 +325,7 @@ class LEditsPPPipelineStableDiffusion(
                 "The scheduler has been changed to DPMSolverMultistepScheduler."
             )
 
-        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+        if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
             deprecation_message = (
                 f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                 f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -332,7 +339,7 @@ class LEditsPPPipelineStableDiffusion(
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
 
-        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+        if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
             deprecation_message = (
                 f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                 " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -361,10 +368,14 @@ class LEditsPPPipelineStableDiffusion(
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
 
-        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
-            version.parse(unet.config._diffusers_version).base_version
-        ) < version.parse("0.9.0.dev0")
-        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        is_unet_version_less_0_9_0 = (
+            unet is not None
+            and hasattr(unet.config, "_diffusers_version")
+            and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
+        )
+        is_unet_sample_size_less_64 = (
+            unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        )
         if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
             deprecation_message = (
                 "The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -391,7 +402,7 @@ class LEditsPPPipelineStableDiffusion(
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.register_to_config(requires_safety_checker=requires_safety_checker)
 
@@ -706,6 +717,35 @@ class LEditsPPPipelineStableDiffusion(
     def cross_attention_kwargs(self):
         return self._cross_attention_kwargs
 
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
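These four helpers mirror the memory controls most other pipelines already expose: slicing decodes a batch one sample at a time, tiling decodes each image in spatial tiles. A usage sketch, reusing the checkpoint from the updated docstring above:

```python
import torch
from diffusers import LEditsPPPipelineStableDiffusion

pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
    "runwayml/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
).to("cuda")

pipe.enable_vae_tiling()   # decode large images tile by tile
pipe.enable_vae_slicing()  # decode batched latents one sample at a time
# ... run pipe.invert(...) and pipe(...) as usual ...
pipe.disable_vae_slicing()
pipe.disable_vae_tiling()
```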
@@ -1182,6 +1222,9 @@ class LEditsPPPipelineStableDiffusion(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # 8. Post-processing
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
@@ -1271,6 +1314,8 @@ class LEditsPPPipelineStableDiffusion(
             [`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s)
             and respective VAE reconstruction(s).
         """
+        if height is not None and height % 32 != 0 or width is not None and width % 32 != 0:
+            raise ValueError("height and width must be a factor of 32.")
         # Reset attn processor, we do not want to store attn maps during inversion
         self.unet.set_attn_processor(AttnProcessor())
 
@@ -1349,6 +1394,9 @@ class LEditsPPPipelineStableDiffusion(
 
                 progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         self.init_latents = xts[-1].expand(self.batch_size, -1, -1, -1)
         zs = zs.flip(0)
         self.zs = zs
@@ -1360,6 +1408,12 @@ class LEditsPPPipelineStableDiffusion(
         image = self.image_processor.preprocess(
             image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
         )
+        height, width = image.shape[-2:]
+        if height % 32 != 0 or width % 32 != 0:
+            raise ValueError(
+                "Image height and width must be a factor of 32. "
+                "Consider down-sampling the input using the `height` and `width` parameters"
+            )
         resized = self.image_processor.postprocess(image=image, output_type="pil")
 
         if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
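Both new checks enforce one invariant: LEDITS++ inversion pushes the image through the VAE and its downsampling stack, which assumes spatial dimensions divisible by 32, so off-size inputs now fail fast with a readable error rather than a shape mismatch deep in the UNet. One way to pre-size arbitrary inputs; `snap_to_multiple_of_32` is an illustrative helper, not pipeline API:

```python
from diffusers.utils import load_image


def snap_to_multiple_of_32(image):
    # Illustrative helper: round each side down to the nearest multiple of 32
    # so pipe.invert() passes the new validation.
    w, h = image.size
    return image.resize((w - w % 32, h - h % 32))


image = load_image("https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/cherry_blossom.png")
image = snap_to_multiple_of_32(image)
```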