diffusers-0.32.1-py3-none-any.whl → diffusers-0.33.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +198 -28
  13. diffusers/loaders/lora_conversion_utils.py +679 -44
  14. diffusers/loaders/lora_pipeline.py +1963 -801
  15. diffusers/loaders/peft.py +169 -84
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +653 -75
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +22 -32
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +10 -2
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +14 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.1.dist-info/RECORD +0 -550
  387. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
@@ -16,10 +16,11 @@ import html
16
16
  import inspect
17
17
  import re
18
18
  import urllib.parse as ul
19
+ import warnings
19
20
  from typing import Callable, Dict, List, Optional, Tuple, Union
20
21
 
21
22
  import torch
22
- from transformers import AutoModelForCausalLM, AutoTokenizer
23
+ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
23
24
 
24
25
  from ...callbacks import MultiPipelineCallbacks, PipelineCallback
25
26
  from ...image_processor import PixArtImageProcessor
@@ -30,6 +31,7 @@ from ...utils import (
30
31
  BACKENDS_MAPPING,
31
32
  is_bs4_available,
32
33
  is_ftfy_available,
34
+ is_torch_xla_available,
33
35
  logging,
34
36
  replace_example_docstring,
35
37
  )
@@ -40,11 +42,20 @@ from ..pixart_alpha.pipeline_pixart_alpha import (
40
42
  ASPECT_RATIO_1024_BIN,
41
43
  )
42
44
  from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
45
+ from ..sana.pipeline_sana import ASPECT_RATIO_4096_BIN
43
46
  from .pag_utils import PAGMixin
44
47
 
45
48
 
49
+ if is_torch_xla_available():
50
+ import torch_xla.core.xla_model as xm
51
+
52
+ XLA_AVAILABLE = True
53
+ else:
54
+ XLA_AVAILABLE = False
55
+
46
56
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
47
57
 
58
+
48
59
  if is_bs4_available():
49
60
  from bs4 import BeautifulSoup
50
61
 
@@ -149,8 +160,8 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
149
160
 
150
161
  def __init__(
151
162
  self,
152
- tokenizer: AutoTokenizer,
153
- text_encoder: AutoModelForCausalLM,
163
+ tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
164
+ text_encoder: Gemma2PreTrainedModel,
154
165
  vae: AutoencoderDC,
155
166
  transformer: SanaTransformer2DModel,
156
167
  scheduler: FlowMatchEulerDiscreteScheduler,
@@ -162,7 +173,11 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
162
173
  tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
163
174
  )
164
175
 
165
- self.vae_scale_factor = 2 ** (len(self.vae.config.encoder_block_out_channels) - 1)
176
+ self.vae_scale_factor = (
177
+ 2 ** (len(self.vae.config.encoder_block_out_channels) - 1)
178
+ if hasattr(self, "vae") and self.vae is not None
179
+ else 8
180
+ )
166
181
  self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
167
182
 
168
183
  self.set_pag_applied_layers(
@@ -170,6 +185,35 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
170
185
  pag_attn_processors=(PAGCFGSanaLinearAttnProcessor2_0(), PAGIdentitySanaLinearAttnProcessor2_0()),
171
186
  )
172
187
 
188
+ def enable_vae_slicing(self):
189
+ r"""
190
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
191
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
192
+ """
193
+ self.vae.enable_slicing()
194
+
195
+ def disable_vae_slicing(self):
196
+ r"""
197
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
198
+ computing decoding in one step.
199
+ """
200
+ self.vae.disable_slicing()
201
+
202
+ def enable_vae_tiling(self):
203
+ r"""
204
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
205
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
206
+ processing larger images.
207
+ """
208
+ self.vae.enable_tiling()
209
+
210
+ def disable_vae_tiling(self):
211
+ r"""
212
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
213
+ computing decoding in one step.
214
+ """
215
+ self.vae.disable_tiling()
216
+
173
217
  def encode_prompt(
174
218
  self,
175
219
  prompt: Union[str, List[str]],
@@ -224,7 +268,8 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
224
268
  else:
225
269
  batch_size = prompt_embeds.shape[0]
226
270
 
227
- self.tokenizer.padding_side = "right"
271
+ if getattr(self, "tokenizer", None) is not None:
272
+ self.tokenizer.padding_side = "right"
228
273
 
229
274
  # See Section 3.1. of the paper.
230
275
  max_length = max_sequence_length
@@ -597,7 +642,7 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
597
642
  negative_prompt_attention_mask: Optional[torch.Tensor] = None,
598
643
  output_type: Optional[str] = "pil",
599
644
  return_dict: bool = True,
600
- clean_caption: bool = True,
645
+ clean_caption: bool = False,
601
646
  use_resolution_binning: bool = True,
602
647
  callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
603
648
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
@@ -713,7 +758,9 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
713
758
  callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
714
759
 
715
760
  if use_resolution_binning:
716
- if self.transformer.config.sample_size == 64:
761
+ if self.transformer.config.sample_size == 128:
762
+ aspect_ratio_bin = ASPECT_RATIO_4096_BIN
763
+ elif self.transformer.config.sample_size == 64:
717
764
  aspect_ratio_bin = ASPECT_RATIO_2048_BIN
718
765
  elif self.transformer.config.sample_size == 32:
719
766
  aspect_ratio_bin = ASPECT_RATIO_1024_BIN
@@ -863,11 +910,21 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
863
910
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
864
911
  progress_bar.update()
865
912
 
913
+ if XLA_AVAILABLE:
914
+ xm.mark_step()
915
+
866
916
  if output_type == "latent":
867
917
  image = latents
868
918
  else:
869
919
  latents = latents.to(self.vae.dtype)
870
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
920
+ try:
921
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
922
+ except torch.cuda.OutOfMemoryError as e:
923
+ warnings.warn(
924
+ f"{e}. \n"
925
+ f"Try to use VAE tiling for large images. For example: \n"
926
+ f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
927
+ )
871
928
  if use_resolution_binning:
872
929
  image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
873
930
 
@@ -27,6 +27,7 @@ from ...schedulers import KarrasDiffusionSchedulers
27
27
  from ...utils import (
28
28
  USE_PEFT_BACKEND,
29
29
  deprecate,
30
+ is_torch_xla_available,
30
31
  logging,
31
32
  replace_example_docstring,
32
33
  scale_lora_layers,
@@ -39,8 +40,16 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
39
40
  from .pag_utils import PAGMixin
40
41
 
41
42
 
43
+ if is_torch_xla_available():
44
+ import torch_xla.core.xla_model as xm
45
+
46
+ XLA_AVAILABLE = True
47
+ else:
48
+ XLA_AVAILABLE = False
49
+
42
50
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
51
 
52
+
44
53
  EXAMPLE_DOC_STRING = """
45
54
  Examples:
46
55
  ```py
@@ -207,7 +216,7 @@ class StableDiffusionPAGPipeline(
207
216
  ):
208
217
  super().__init__()
209
218
 
210
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
219
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
211
220
  deprecation_message = (
212
221
  f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
213
222
  f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -221,7 +230,7 @@ class StableDiffusionPAGPipeline(
221
230
  new_config["steps_offset"] = 1
222
231
  scheduler._internal_dict = FrozenDict(new_config)
223
232
 
224
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
233
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
225
234
  deprecation_message = (
226
235
  f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
227
236
  " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -250,10 +259,14 @@ class StableDiffusionPAGPipeline(
250
259
  " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
251
260
  )
252
261
 
253
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
254
- version.parse(unet.config._diffusers_version).base_version
255
- ) < version.parse("0.9.0.dev0")
256
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
262
+ is_unet_version_less_0_9_0 = (
263
+ unet is not None
264
+ and hasattr(unet.config, "_diffusers_version")
265
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
266
+ )
267
+ is_unet_sample_size_less_64 = (
268
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
269
+ )
257
270
  if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
258
271
  deprecation_message = (
259
272
  "The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -281,7 +294,7 @@ class StableDiffusionPAGPipeline(
281
294
  feature_extractor=feature_extractor,
282
295
  image_encoder=image_encoder,
283
296
  )
284
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
297
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
285
298
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
286
299
  self.register_to_config(requires_safety_checker=requires_safety_checker)
287
300
 
@@ -1034,6 +1047,9 @@ class StableDiffusionPAGPipeline(
1034
1047
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1035
1048
  progress_bar.update()
1036
1049
 
1050
+ if XLA_AVAILABLE:
1051
+ xm.mark_step()
1052
+
1037
1053
  if not output_type == "latent":
1038
1054
  image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
1039
1055
  0
@@ -200,9 +200,7 @@ class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSin
200
200
  transformer=transformer,
201
201
  scheduler=scheduler,
202
202
  )
203
- self.vae_scale_factor = (
204
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
205
- )
203
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
206
204
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
207
205
  self.tokenizer_max_length = (
208
206
  self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
@@ -377,9 +375,9 @@ class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSin
377
375
  negative_prompt_2 (`str` or `List[str]`, *optional*):
378
376
  The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
379
377
  `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
380
- negative_prompt_2 (`str` or `List[str]`, *optional*):
378
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
381
379
  The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
382
- `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
380
+ `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
383
381
  prompt_embeds (`torch.FloatTensor`, *optional*):
384
382
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
385
383
  provided, text embeddings will be generated from `prompt` input argument.
@@ -216,9 +216,7 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
216
216
  transformer=transformer,
217
217
  scheduler=scheduler,
218
218
  )
219
- self.vae_scale_factor = (
220
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
221
- )
219
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
222
220
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
223
221
  self.tokenizer_max_length = (
224
222
  self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
@@ -393,9 +391,9 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
393
391
  negative_prompt_2 (`str` or `List[str]`, *optional*):
394
392
  The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
395
393
  `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
396
- negative_prompt_2 (`str` or `List[str]`, *optional*):
394
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
397
395
  The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
398
- `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
396
+ `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
399
397
  prompt_embeds (`torch.FloatTensor`, *optional*):
400
398
  Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
401
399
  provided, text embeddings will be generated from `prompt` input argument.
@@ -26,6 +26,7 @@ from ...models.unets.unet_motion_model import MotionAdapter
26
26
  from ...schedulers import KarrasDiffusionSchedulers
27
27
  from ...utils import (
28
28
  USE_PEFT_BACKEND,
29
+ is_torch_xla_available,
29
30
  logging,
30
31
  replace_example_docstring,
31
32
  scale_lora_layers,
@@ -40,8 +41,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
40
41
  from .pag_utils import PAGMixin
41
42
 
42
43
 
44
+ if is_torch_xla_available():
45
+ import torch_xla.core.xla_model as xm
46
+
47
+ XLA_AVAILABLE = True
48
+ else:
49
+ XLA_AVAILABLE = False
50
+
43
51
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
44
52
 
53
+
45
54
  EXAMPLE_DOC_STRING = """
46
55
  Examples:
47
56
  ```py
@@ -147,7 +156,7 @@ class AnimateDiffPAGPipeline(
147
156
  feature_extractor=feature_extractor,
148
157
  image_encoder=image_encoder,
149
158
  )
150
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
159
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
151
160
  self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
152
161
 
153
162
  self.set_pag_applied_layers(pag_applied_layers)
@@ -847,6 +856,9 @@ class AnimateDiffPAGPipeline(
847
856
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
848
857
  progress_bar.update()
849
858
 
859
+ if XLA_AVAILABLE:
860
+ xm.mark_step()
861
+
850
862
  # 9. Post processing
851
863
  if output_type == "latent":
852
864
  video = latents
@@ -30,6 +30,7 @@ from ...schedulers import KarrasDiffusionSchedulers
30
30
  from ...utils import (
31
31
  USE_PEFT_BACKEND,
32
32
  deprecate,
33
+ is_torch_xla_available,
33
34
  logging,
34
35
  replace_example_docstring,
35
36
  scale_lora_layers,
@@ -42,8 +43,16 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
42
43
  from .pag_utils import PAGMixin
43
44
 
44
45
 
46
+ if is_torch_xla_available():
47
+ import torch_xla.core.xla_model as xm
48
+
49
+ XLA_AVAILABLE = True
50
+ else:
51
+ XLA_AVAILABLE = False
52
+
45
53
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
46
54
 
55
+
47
56
  EXAMPLE_DOC_STRING = """
48
57
  Examples:
49
58
  ```py
@@ -202,7 +211,7 @@ class StableDiffusionPAGImg2ImgPipeline(
202
211
  ):
203
212
  super().__init__()
204
213
 
205
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
214
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
206
215
  deprecation_message = (
207
216
  f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
208
217
  f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -216,7 +225,7 @@ class StableDiffusionPAGImg2ImgPipeline(
216
225
  new_config["steps_offset"] = 1
217
226
  scheduler._internal_dict = FrozenDict(new_config)
218
227
 
219
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
228
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
220
229
  deprecation_message = (
221
230
  f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
222
231
  " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -245,10 +254,14 @@ class StableDiffusionPAGImg2ImgPipeline(
245
254
  " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
246
255
  )
247
256
 
248
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
249
- version.parse(unet.config._diffusers_version).base_version
250
- ) < version.parse("0.9.0.dev0")
251
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
257
+ is_unet_version_less_0_9_0 = (
258
+ unet is not None
259
+ and hasattr(unet.config, "_diffusers_version")
260
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
261
+ )
262
+ is_unet_sample_size_less_64 = (
263
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
264
+ )
252
265
  if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
253
266
  deprecation_message = (
254
267
  "The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -276,7 +289,7 @@ class StableDiffusionPAGImg2ImgPipeline(
276
289
  feature_extractor=feature_extractor,
277
290
  image_encoder=image_encoder,
278
291
  )
279
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
292
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
280
293
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
281
294
  self.register_to_config(requires_safety_checker=requires_safety_checker)
282
295
 
@@ -1066,6 +1079,9 @@ class StableDiffusionPAGImg2ImgPipeline(
1066
1079
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1067
1080
  progress_bar.update()
1068
1081
 
1082
+ if XLA_AVAILABLE:
1083
+ xm.mark_step()
1084
+
1069
1085
  if not output_type == "latent":
1070
1086
  image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
1071
1087
  0
@@ -28,6 +28,7 @@ from ...schedulers import KarrasDiffusionSchedulers
28
28
  from ...utils import (
29
29
  USE_PEFT_BACKEND,
30
30
  deprecate,
31
+ is_torch_xla_available,
31
32
  logging,
32
33
  replace_example_docstring,
33
34
  scale_lora_layers,
@@ -40,8 +41,16 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
40
41
  from .pag_utils import PAGMixin
41
42
 
42
43
 
44
+ if is_torch_xla_available():
45
+ import torch_xla.core.xla_model as xm
46
+
47
+ XLA_AVAILABLE = True
48
+ else:
49
+ XLA_AVAILABLE = False
50
+
43
51
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
44
52
 
53
+
45
54
  EXAMPLE_DOC_STRING = """
46
55
  Examples:
47
56
  ```py
@@ -234,7 +243,7 @@ class StableDiffusionPAGInpaintPipeline(
234
243
  ):
235
244
  super().__init__()
236
245
 
237
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
246
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
238
247
  deprecation_message = (
239
248
  f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
240
249
  f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -248,7 +257,7 @@ class StableDiffusionPAGInpaintPipeline(
248
257
  new_config["steps_offset"] = 1
249
258
  scheduler._internal_dict = FrozenDict(new_config)
250
259
 
251
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
260
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
252
261
  deprecation_message = (
253
262
  f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
254
263
  " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -277,10 +286,14 @@ class StableDiffusionPAGInpaintPipeline(
277
286
  " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
278
287
  )
279
288
 
280
- is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
281
- version.parse(unet.config._diffusers_version).base_version
282
- ) < version.parse("0.9.0.dev0")
283
- is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
289
+ is_unet_version_less_0_9_0 = (
290
+ unet is not None
291
+ and hasattr(unet.config, "_diffusers_version")
292
+ and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
293
+ )
294
+ is_unet_sample_size_less_64 = (
295
+ unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
296
+ )
284
297
  if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
285
298
  deprecation_message = (
286
299
  "The configuration file of the unet has set the default `sample_size` to smaller than"
@@ -308,7 +321,7 @@ class StableDiffusionPAGInpaintPipeline(
308
321
  feature_extractor=feature_extractor,
309
322
  image_encoder=image_encoder,
310
323
  )
311
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
324
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
312
325
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
313
326
  self.mask_processor = VaeImageProcessor(
314
327
  vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
@@ -670,7 +683,7 @@ class StableDiffusionPAGInpaintPipeline(
670
683
  if padding_mask_crop is not None:
671
684
  if not isinstance(image, PIL.Image.Image):
672
685
  raise ValueError(
673
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
686
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
674
687
  )
675
688
  if not isinstance(mask_image, PIL.Image.Image):
676
689
  raise ValueError(
@@ -678,7 +691,7 @@ class StableDiffusionPAGInpaintPipeline(
678
691
  f" {type(mask_image)}."
679
692
  )
680
693
  if output_type != "pil":
681
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
694
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
682
695
 
683
696
  if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
684
697
  raise ValueError(
@@ -1178,7 +1191,7 @@ class StableDiffusionPAGInpaintPipeline(
1178
1191
  f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
1179
1192
  f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
1180
1193
  f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
1181
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
1194
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
1182
1195
  " `pipeline.unet` or your `mask_image` or `image` input."
1183
1196
  )
1184
1197
  elif num_channels_unet != 4:
@@ -1318,6 +1331,9 @@ class StableDiffusionPAGInpaintPipeline(
1318
1331
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1319
1332
  progress_bar.update()
1320
1333
 
1334
+ if XLA_AVAILABLE:
1335
+ xm.mark_step()
1336
+
1321
1337
  if not output_type == "latent":
1322
1338
  condition_kwargs = {}
1323
1339
  if isinstance(self.vae, AsymmetricAutoencoderKL):
@@ -275,10 +275,14 @@ class StableDiffusionXLPAGPipeline(
275
275
  feature_extractor=feature_extractor,
276
276
  )
277
277
  self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
278
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
278
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
279
279
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
280
280
 
281
- self.default_sample_size = self.unet.config.sample_size
281
+ self.default_sample_size = (
282
+ self.unet.config.sample_size
283
+ if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
284
+ else 128
285
+ )
282
286
 
283
287
  add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
284
288
 
@@ -415,7 +419,9 @@ class StableDiffusionXLPAGPipeline(
415
419
  prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
416
420
 
417
421
  # We are only ALWAYS interested in the pooled output of the final text encoder
418
- pooled_prompt_embeds = prompt_embeds[0]
422
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
423
+ pooled_prompt_embeds = prompt_embeds[0]
424
+
419
425
  if clip_skip is None:
420
426
  prompt_embeds = prompt_embeds.hidden_states[-2]
421
427
  else:
@@ -474,8 +480,10 @@ class StableDiffusionXLPAGPipeline(
474
480
  uncond_input.input_ids.to(device),
475
481
  output_hidden_states=True,
476
482
  )
483
+
477
484
  # We are only ALWAYS interested in the pooled output of the final text encoder
478
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
485
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
486
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
479
487
  negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
480
488
 
481
489
  negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -298,7 +298,7 @@ class StableDiffusionXLPAGImg2ImgPipeline(
298
298
  )
299
299
  self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
300
300
  self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
301
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
301
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
302
302
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
303
303
 
304
304
  add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -436,7 +436,9 @@ class StableDiffusionXLPAGImg2ImgPipeline(
436
436
  prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
437
437
 
438
438
  # We are only ALWAYS interested in the pooled output of the final text encoder
439
- pooled_prompt_embeds = prompt_embeds[0]
439
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
440
+ pooled_prompt_embeds = prompt_embeds[0]
441
+
440
442
  if clip_skip is None:
441
443
  prompt_embeds = prompt_embeds.hidden_states[-2]
442
444
  else:
@@ -495,8 +497,10 @@ class StableDiffusionXLPAGImg2ImgPipeline(
495
497
  uncond_input.input_ids.to(device),
496
498
  output_hidden_states=True,
497
499
  )
500
+
498
501
  # We are only ALWAYS interested in the pooled output of the final text encoder
499
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
502
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
503
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
500
504
  negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
501
505
 
502
506
  negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -314,7 +314,7 @@ class StableDiffusionXLPAGInpaintPipeline(
314
314
  )
315
315
  self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
316
316
  self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
317
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
317
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
318
318
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
319
319
  self.mask_processor = VaeImageProcessor(
320
320
  vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
@@ -526,7 +526,9 @@ class StableDiffusionXLPAGInpaintPipeline(
526
526
  prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
527
527
 
528
528
  # We are only ALWAYS interested in the pooled output of the final text encoder
529
- pooled_prompt_embeds = prompt_embeds[0]
529
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
530
+ pooled_prompt_embeds = prompt_embeds[0]
531
+
530
532
  if clip_skip is None:
531
533
  prompt_embeds = prompt_embeds.hidden_states[-2]
532
534
  else:
@@ -585,8 +587,10 @@ class StableDiffusionXLPAGInpaintPipeline(
585
587
  uncond_input.input_ids.to(device),
586
588
  output_hidden_states=True,
587
589
  )
590
+
588
591
  # We are only ALWAYS interested in the pooled output of the final text encoder
589
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
592
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
593
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
590
594
  negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
591
595
 
592
596
  negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -733,7 +737,7 @@ class StableDiffusionXLPAGInpaintPipeline(
733
737
  if padding_mask_crop is not None:
734
738
  if not isinstance(image, PIL.Image.Image):
735
739
  raise ValueError(
736
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
740
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
737
741
  )
738
742
  if not isinstance(mask_image, PIL.Image.Image):
739
743
  raise ValueError(
@@ -741,7 +745,7 @@ class StableDiffusionXLPAGInpaintPipeline(
741
745
  f" {type(mask_image)}."
742
746
  )
743
747
  if output_type != "pil":
744
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
748
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
745
749
 
746
750
  if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
747
751
  raise ValueError(
@@ -1505,7 +1509,7 @@ class StableDiffusionXLPAGInpaintPipeline(
1505
1509
  f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
1506
1510
  f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
1507
1511
  f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
1508
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
1512
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
1509
1513
  " `pipeline.unet` or your `mask_image` or `image` input."
1510
1514
  )
1511
1515
  elif num_channels_unet != 4: