diffusers 0.26.3__py3-none-any.whl → 0.27.0__py3-none-any.whl

Files changed (299)
  1. diffusers/__init__.py +20 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +7 -3
  7. diffusers/dependency_versions_check.py +1 -1
  8. diffusers/dependency_versions_table.py +2 -2
  9. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  10. diffusers/image_processor.py +110 -4
  11. diffusers/loaders/autoencoder.py +7 -8
  12. diffusers/loaders/controlnet.py +17 -8
  13. diffusers/loaders/ip_adapter.py +86 -23
  14. diffusers/loaders/lora.py +105 -310
  15. diffusers/loaders/lora_conversion_utils.py +1 -1
  16. diffusers/loaders/peft.py +1 -1
  17. diffusers/loaders/single_file.py +51 -12
  18. diffusers/loaders/single_file_utils.py +274 -49
  19. diffusers/loaders/textual_inversion.py +23 -4
  20. diffusers/loaders/unet.py +195 -41
  21. diffusers/loaders/utils.py +1 -1
  22. diffusers/models/__init__.py +3 -1
  23. diffusers/models/activations.py +9 -9
  24. diffusers/models/attention.py +26 -36
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +171 -114
  27. diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
  28. diffusers/models/autoencoders/autoencoder_kl.py +3 -1
  29. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  30. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  31. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  32. diffusers/models/autoencoders/vae.py +1 -1
  33. diffusers/models/controlnet.py +1 -1
  34. diffusers/models/controlnet_flax.py +1 -1
  35. diffusers/models/downsampling.py +8 -12
  36. diffusers/models/dual_transformer_2d.py +1 -1
  37. diffusers/models/embeddings.py +3 -4
  38. diffusers/models/embeddings_flax.py +1 -1
  39. diffusers/models/lora.py +33 -10
  40. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  41. diffusers/models/modeling_flax_utils.py +1 -1
  42. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  43. diffusers/models/modeling_utils.py +4 -6
  44. diffusers/models/normalization.py +1 -1
  45. diffusers/models/resnet.py +31 -58
  46. diffusers/models/resnet_flax.py +1 -1
  47. diffusers/models/t5_film_transformer.py +1 -1
  48. diffusers/models/transformer_2d.py +1 -1
  49. diffusers/models/transformer_temporal.py +1 -1
  50. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  51. diffusers/models/transformers/t5_film_transformer.py +1 -1
  52. diffusers/models/transformers/transformer_2d.py +29 -31
  53. diffusers/models/transformers/transformer_temporal.py +1 -1
  54. diffusers/models/unet_1d.py +1 -1
  55. diffusers/models/unet_1d_blocks.py +1 -1
  56. diffusers/models/unet_2d.py +1 -1
  57. diffusers/models/unet_2d_blocks.py +1 -1
  58. diffusers/models/unet_2d_condition.py +1 -1
  59. diffusers/models/unets/__init__.py +1 -0
  60. diffusers/models/unets/unet_1d.py +1 -1
  61. diffusers/models/unets/unet_1d_blocks.py +1 -1
  62. diffusers/models/unets/unet_2d.py +4 -4
  63. diffusers/models/unets/unet_2d_blocks.py +238 -98
  64. diffusers/models/unets/unet_2d_blocks_flax.py +1 -1
  65. diffusers/models/unets/unet_2d_condition.py +420 -323
  66. diffusers/models/unets/unet_2d_condition_flax.py +21 -12
  67. diffusers/models/unets/unet_3d_blocks.py +50 -40
  68. diffusers/models/unets/unet_3d_condition.py +47 -8
  69. diffusers/models/unets/unet_i2vgen_xl.py +75 -30
  70. diffusers/models/unets/unet_kandinsky3.py +1 -1
  71. diffusers/models/unets/unet_motion_model.py +48 -8
  72. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  73. diffusers/models/unets/unet_stable_cascade.py +610 -0
  74. diffusers/models/unets/uvit_2d.py +1 -1
  75. diffusers/models/upsampling.py +10 -16
  76. diffusers/models/vae_flax.py +1 -1
  77. diffusers/models/vq_model.py +1 -1
  78. diffusers/optimization.py +1 -1
  79. diffusers/pipelines/__init__.py +26 -0
  80. diffusers/pipelines/amused/pipeline_amused.py +1 -1
  81. diffusers/pipelines/amused/pipeline_amused_img2img.py +1 -1
  82. diffusers/pipelines/amused/pipeline_amused_inpaint.py +1 -1
  83. diffusers/pipelines/animatediff/pipeline_animatediff.py +162 -417
  84. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +165 -137
  85. diffusers/pipelines/animatediff/pipeline_output.py +7 -6
  86. diffusers/pipelines/audioldm/pipeline_audioldm.py +3 -19
  87. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  88. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +3 -3
  89. diffusers/pipelines/auto_pipeline.py +7 -16
  90. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  93. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  94. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  95. diffusers/pipelines/controlnet/pipeline_controlnet.py +90 -90
  96. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  97. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +98 -90
  98. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +92 -90
  99. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +145 -70
  100. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +126 -89
  101. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +108 -96
  102. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  103. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -1
  104. diffusers/pipelines/ddim/pipeline_ddim.py +1 -1
  105. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  106. diffusers/pipelines/deepfloyd_if/pipeline_if.py +4 -4
  107. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +4 -4
  108. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +5 -5
  109. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +4 -4
  110. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +5 -5
  111. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +5 -5
  112. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +10 -120
  113. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -91
  114. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  115. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +1 -1
  116. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  117. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +1 -1
  118. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  119. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  120. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  121. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  122. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  123. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  124. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +5 -4
  125. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +5 -4
  126. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +7 -22
  127. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -39
  128. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +5 -5
  129. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  130. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -22
  131. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  132. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  133. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -2
  134. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  135. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  136. diffusers/pipelines/free_init_utils.py +184 -0
  137. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +22 -104
  138. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +1 -1
  139. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  140. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +1 -1
  141. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +2 -2
  142. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -1
  143. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -1
  145. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +1 -1
  146. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
  147. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  148. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +2 -2
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +104 -93
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +112 -74
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/ledits_pp/__init__.py +55 -0
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +1505 -0
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +1797 -0
  155. diffusers/pipelines/ledits_pp/pipeline_output.py +43 -0
  156. diffusers/pipelines/musicldm/pipeline_musicldm.py +3 -19
  157. diffusers/pipelines/onnx_utils.py +1 -1
  158. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  159. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +3 -3
  160. diffusers/pipelines/pia/pipeline_pia.py +168 -327
  161. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  162. diffusers/pipelines/pipeline_loading_utils.py +508 -0
  163. diffusers/pipelines/pipeline_utils.py +188 -534
  164. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +56 -10
  165. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +3 -3
  166. diffusers/pipelines/shap_e/camera.py +1 -1
  167. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  168. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  169. diffusers/pipelines/shap_e/renderer.py +1 -1
  170. diffusers/pipelines/stable_cascade/__init__.py +50 -0
  171. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +482 -0
  172. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +311 -0
  173. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +638 -0
  174. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  175. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +4 -1
  176. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  177. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +2 -2
  178. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  179. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +1 -1
  180. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +1 -1
  181. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  182. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  183. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +90 -146
  184. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
  185. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +4 -32
  186. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -119
  187. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -119
  188. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +13 -59
  189. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -31
  190. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -33
  191. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -21
  192. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -21
  193. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  194. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  195. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  196. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +5 -21
  197. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +9 -38
  198. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -34
  199. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +6 -35
  200. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +7 -6
  201. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +4 -124
  202. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +282 -80
  203. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +94 -46
  204. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +3 -3
  205. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  206. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +6 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  208. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +96 -148
  209. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -154
  210. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +98 -153
  211. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +25 -87
  212. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +89 -80
  213. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +5 -49
  214. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +80 -88
  215. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +8 -6
  216. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +15 -86
  217. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +20 -93
  218. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +5 -5
  219. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +3 -19
  220. diffusers/pipelines/unclip/pipeline_unclip.py +1 -1
  221. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -1
  222. diffusers/pipelines/unclip/text_proj.py +1 -1
  223. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +35 -35
  224. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  225. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +4 -21
  226. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +2 -2
  227. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -5
  228. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  229. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +1 -1
  230. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +2 -2
  231. diffusers/schedulers/__init__.py +7 -1
  232. diffusers/schedulers/deprecated/scheduling_karras_ve.py +1 -1
  233. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  234. diffusers/schedulers/scheduling_consistency_models.py +42 -19
  235. diffusers/schedulers/scheduling_ddim.py +2 -4
  236. diffusers/schedulers/scheduling_ddim_flax.py +13 -5
  237. diffusers/schedulers/scheduling_ddim_inverse.py +2 -4
  238. diffusers/schedulers/scheduling_ddim_parallel.py +2 -4
  239. diffusers/schedulers/scheduling_ddpm.py +2 -4
  240. diffusers/schedulers/scheduling_ddpm_flax.py +1 -1
  241. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -4
  242. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +1 -1
  243. diffusers/schedulers/scheduling_deis_multistep.py +46 -19
  244. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -21
  245. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +1 -1
  246. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +9 -7
  247. diffusers/schedulers/scheduling_dpmsolver_sde.py +35 -35
  248. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +49 -18
  249. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +683 -0
  250. diffusers/schedulers/scheduling_edm_euler.py +381 -0
  251. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +43 -15
  252. diffusers/schedulers/scheduling_euler_discrete.py +42 -17
  253. diffusers/schedulers/scheduling_euler_discrete_flax.py +1 -1
  254. diffusers/schedulers/scheduling_heun_discrete.py +35 -35
  255. diffusers/schedulers/scheduling_ipndm.py +37 -11
  256. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +44 -44
  257. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +44 -44
  258. diffusers/schedulers/scheduling_karras_ve_flax.py +1 -1
  259. diffusers/schedulers/scheduling_lcm.py +38 -14
  260. diffusers/schedulers/scheduling_lms_discrete.py +43 -15
  261. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  262. diffusers/schedulers/scheduling_pndm.py +2 -4
  263. diffusers/schedulers/scheduling_pndm_flax.py +2 -4
  264. diffusers/schedulers/scheduling_repaint.py +1 -1
  265. diffusers/schedulers/scheduling_sasolver.py +41 -9
  266. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  267. diffusers/schedulers/scheduling_sde_ve_flax.py +1 -1
  268. diffusers/schedulers/scheduling_tcd.py +686 -0
  269. diffusers/schedulers/scheduling_unclip.py +1 -1
  270. diffusers/schedulers/scheduling_unipc_multistep.py +46 -19
  271. diffusers/schedulers/scheduling_utils.py +2 -1
  272. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  273. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  274. diffusers/training_utils.py +9 -2
  275. diffusers/utils/__init__.py +2 -1
  276. diffusers/utils/accelerate_utils.py +1 -1
  277. diffusers/utils/constants.py +1 -1
  278. diffusers/utils/doc_utils.py +1 -1
  279. diffusers/utils/dummy_pt_objects.py +60 -0
  280. diffusers/utils/dummy_torch_and_transformers_objects.py +75 -0
  281. diffusers/utils/dynamic_modules_utils.py +1 -1
  282. diffusers/utils/export_utils.py +3 -3
  283. diffusers/utils/hub_utils.py +60 -16
  284. diffusers/utils/import_utils.py +15 -1
  285. diffusers/utils/loading_utils.py +2 -0
  286. diffusers/utils/logging.py +1 -1
  287. diffusers/utils/model_card_template.md +24 -0
  288. diffusers/utils/outputs.py +14 -7
  289. diffusers/utils/peft_utils.py +1 -1
  290. diffusers/utils/state_dict_utils.py +1 -1
  291. diffusers/utils/testing_utils.py +2 -0
  292. diffusers/utils/torch_utils.py +1 -1
  293. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/METADATA +46 -46
  294. diffusers-0.27.0.dist-info/RECORD +399 -0
  295. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/WHEEL +1 -1
  296. diffusers-0.26.3.dist-info/RECORD +0 -384
  297. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/LICENSE +0 -0
  298. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/entry_points.txt +0 -0
  299. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/top_level.txt +0 -0
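
The headline additions in 0.27.0 visible in the list above are the Stable Cascade and LEDITS++ pipelines, the new `pipeline_loading_utils.py` split out of `pipeline_utils.py`, and the new TCD and EDM schedulers. Below is a minimal usage sketch of the new `LEditsPPPipelineStableDiffusionXL`, adapted from the example docstring inside the new `pipeline_leditspp_stable_diffusion_xl.py` whose diff follows; the model id, image URL, and editing parameters are taken from that docstring rather than verified independently.

```py
import torch
import PIL
import requests
from io import BytesIO

from diffusers import LEditsPPPipelineStableDiffusionXL

# Load the new LEDITS++ SDXL editing pipeline added in 0.27.0.
pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")


def download_image(url):
    response = requests.get(url)
    return PIL.Image.open(BytesIO(response.content)).convert("RGB")


image = download_image("https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg")

# Invert the input image first; edits are always applied to the last inverted image(s).
_ = pipe.invert(image=image, num_inversion_steps=50, skip=0.2)

# Apply semantic edits: remove the tennis ball, add a tomato.
edited_image = pipe(
    editing_prompt=["tennis ball", "tomato"],
    reverse_editing_direction=[True, False],
    edit_guidance_scale=[5.0, 10.0],
    edit_threshold=[0.9, 0.85],
).images[0]
```

As documented in the pipeline's `__call__` docstring below, `invert()` has to be called before the edit call.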
diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py (new file)
@@ -0,0 +1,1797 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ import math
17
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn.functional as F
21
+ from transformers import (
22
+ CLIPImageProcessor,
23
+ CLIPTextModel,
24
+ CLIPTextModelWithProjection,
25
+ CLIPTokenizer,
26
+ CLIPVisionModelWithProjection,
27
+ )
28
+
29
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
30
+ from ...loaders import (
31
+ FromSingleFileMixin,
32
+ IPAdapterMixin,
33
+ StableDiffusionXLLoraLoaderMixin,
34
+ TextualInversionLoaderMixin,
35
+ )
36
+ from ...models import AutoencoderKL, UNet2DConditionModel
37
+ from ...models.attention_processor import (
38
+ Attention,
39
+ AttnProcessor,
40
+ AttnProcessor2_0,
41
+ LoRAAttnProcessor2_0,
42
+ LoRAXFormersAttnProcessor,
43
+ XFormersAttnProcessor,
44
+ )
45
+ from ...models.lora import adjust_lora_scale_text_encoder
46
+ from ...schedulers import DDIMScheduler, DPMSolverMultistepScheduler
47
+ from ...utils import (
48
+ USE_PEFT_BACKEND,
49
+ is_invisible_watermark_available,
50
+ is_torch_xla_available,
51
+ logging,
52
+ replace_example_docstring,
53
+ scale_lora_layers,
54
+ unscale_lora_layers,
55
+ )
56
+ from ...utils.torch_utils import randn_tensor
57
+ from ..pipeline_utils import DiffusionPipeline
58
+ from .pipeline_output import LEditsPPDiffusionPipelineOutput, LEditsPPInversionPipelineOutput
59
+
60
+
61
+ if is_invisible_watermark_available():
62
+ from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
63
+
64
+ if is_torch_xla_available():
65
+ import torch_xla.core.xla_model as xm
66
+
67
+ XLA_AVAILABLE = True
68
+ else:
69
+ XLA_AVAILABLE = False
70
+
71
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
72
+
73
+ EXAMPLE_DOC_STRING = """
74
+ Examples:
75
+ ```py
76
+ >>> import torch
77
+ >>> import PIL
78
+ >>> import requests
79
+ >>> from io import BytesIO
80
+
81
+ >>> from diffusers import LEditsPPPipelineStableDiffusionXL
82
+
83
+ >>> pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained(
84
+ ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
85
+ ... )
86
+ >>> pipe = pipe.to("cuda")
87
+
88
+ >>> def download_image(url):
89
+ ... response = requests.get(url)
90
+ ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
91
+
92
+ >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg"
93
+ >>> image = download_image(img_url)
94
+
95
+ >>> _ = pipe.invert(
96
+ ... image = image,
97
+ ... num_inversion_steps=50,
98
+ ... skip=0.2
99
+ ... )
100
+
101
+ >>> edited_image = pipe(
102
+ ... editing_prompt=["tennis ball","tomato"],
103
+ ... reverse_editing_direction=[True,False],
104
+ ... edit_guidance_scale=[5.0,10.0],
105
+ ... edit_threshold=[0.9,0.85],
106
+ ... ).images[0]
107
+ ```
108
+ """
109
+
110
+
111
+ # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LeditsAttentionStore
112
+ class LeditsAttentionStore:
113
+ @staticmethod
114
+ def get_empty_store():
115
+ return {"down_cross": [], "mid_cross": [], "up_cross": [], "down_self": [], "mid_self": [], "up_self": []}
116
+
117
+ def __call__(self, attn, is_cross: bool, place_in_unet: str, editing_prompts, PnP=False):
118
+ # attn.shape = batch_size * head_size, seq_len query, seq_len_key
119
+ if attn.shape[1] <= self.max_size:
120
+ bs = 1 + int(PnP) + editing_prompts
121
+ skip = 2 if PnP else 1 # skip PnP & unconditional
122
+ attn = torch.stack(attn.split(self.batch_size)).permute(1, 0, 2, 3)
123
+ source_batch_size = int(attn.shape[1] // bs)
124
+ self.forward(attn[:, skip * source_batch_size :], is_cross, place_in_unet)
125
+
126
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
127
+ key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
128
+
129
+ self.step_store[key].append(attn)
130
+
131
+ def between_steps(self, store_step=True):
132
+ if store_step:
133
+ if self.average:
134
+ if len(self.attention_store) == 0:
135
+ self.attention_store = self.step_store
136
+ else:
137
+ for key in self.attention_store:
138
+ for i in range(len(self.attention_store[key])):
139
+ self.attention_store[key][i] += self.step_store[key][i]
140
+ else:
141
+ if len(self.attention_store) == 0:
142
+ self.attention_store = [self.step_store]
143
+ else:
144
+ self.attention_store.append(self.step_store)
145
+
146
+ self.cur_step += 1
147
+ self.step_store = self.get_empty_store()
148
+
149
+ def get_attention(self, step: int):
150
+ if self.average:
151
+ attention = {
152
+ key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store
153
+ }
154
+ else:
155
+ assert step is not None
156
+ attention = self.attention_store[step]
157
+ return attention
158
+
159
+ def aggregate_attention(
160
+ self, attention_maps, prompts, res: Union[int, Tuple[int]], from_where: List[str], is_cross: bool, select: int
161
+ ):
162
+ out = [[] for x in range(self.batch_size)]
163
+ if isinstance(res, int):
164
+ num_pixels = res**2
165
+ resolution = (res, res)
166
+ else:
167
+ num_pixels = res[0] * res[1]
168
+ resolution = res[:2]
169
+
170
+ for location in from_where:
171
+ for bs_item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
172
+ for batch, item in enumerate(bs_item):
173
+ if item.shape[1] == num_pixels:
174
+ cross_maps = item.reshape(len(prompts), -1, *resolution, item.shape[-1])[select]
175
+ out[batch].append(cross_maps)
176
+
177
+ out = torch.stack([torch.cat(x, dim=0) for x in out])
178
+ # average over heads
179
+ out = out.sum(1) / out.shape[1]
180
+ return out
181
+
182
+ def __init__(self, average: bool, batch_size=1, max_resolution=16, max_size: int = None):
183
+ self.step_store = self.get_empty_store()
184
+ self.attention_store = []
185
+ self.cur_step = 0
186
+ self.average = average
187
+ self.batch_size = batch_size
188
+ if max_size is None:
189
+ self.max_size = max_resolution**2
190
+ elif max_size is not None and max_resolution is None:
191
+ self.max_size = max_size
192
+ else:
193
+ raise ValueError("Only allowed to set one of max_resolution or max_size")
194
+
195
+
196
+ # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LeditsGaussianSmoothing
197
+ class LeditsGaussianSmoothing:
198
+ def __init__(self, device):
199
+ kernel_size = [3, 3]
200
+ sigma = [0.5, 0.5]
201
+
202
+ # The gaussian kernel is the product of the gaussian function of each dimension.
203
+ kernel = 1
204
+ meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
205
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
206
+ mean = (size - 1) / 2
207
+ kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
208
+
209
+ # Make sure sum of values in gaussian kernel equals 1.
210
+ kernel = kernel / torch.sum(kernel)
211
+
212
+ # Reshape to depthwise convolutional weight
213
+ kernel = kernel.view(1, 1, *kernel.size())
214
+ kernel = kernel.repeat(1, *[1] * (kernel.dim() - 1))
215
+
216
+ self.weight = kernel.to(device)
217
+
218
+ def __call__(self, input):
219
+ """
220
+ Apply gaussian filter to input.
221
+ Arguments:
222
+ input (torch.Tensor): Input to apply gaussian filter on.
223
+ Returns:
224
+ filtered (torch.Tensor): Filtered output.
225
+ """
226
+ return F.conv2d(input, weight=self.weight.to(input.dtype))
227
+
228
+
229
+ # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEDITSCrossAttnProcessor
230
+ class LEDITSCrossAttnProcessor:
231
+ def __init__(self, attention_store, place_in_unet, pnp, editing_prompts):
232
+ self.attnstore = attention_store
233
+ self.place_in_unet = place_in_unet
234
+ self.editing_prompts = editing_prompts
235
+ self.pnp = pnp
236
+
237
+ def __call__(
238
+ self,
239
+ attn: Attention,
240
+ hidden_states,
241
+ encoder_hidden_states,
242
+ attention_mask=None,
243
+ temb=None,
244
+ ):
245
+ batch_size, sequence_length, _ = (
246
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
247
+ )
248
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
249
+
250
+ query = attn.to_q(hidden_states)
251
+
252
+ if encoder_hidden_states is None:
253
+ encoder_hidden_states = hidden_states
254
+ elif attn.norm_cross:
255
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
256
+
257
+ key = attn.to_k(encoder_hidden_states)
258
+ value = attn.to_v(encoder_hidden_states)
259
+
260
+ query = attn.head_to_batch_dim(query)
261
+ key = attn.head_to_batch_dim(key)
262
+ value = attn.head_to_batch_dim(value)
263
+
264
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
265
+ self.attnstore(
266
+ attention_probs,
267
+ is_cross=True,
268
+ place_in_unet=self.place_in_unet,
269
+ editing_prompts=self.editing_prompts,
270
+ PnP=self.pnp,
271
+ )
272
+
273
+ hidden_states = torch.bmm(attention_probs, value)
274
+ hidden_states = attn.batch_to_head_dim(hidden_states)
275
+
276
+ # linear proj
277
+ hidden_states = attn.to_out[0](hidden_states)
278
+ # dropout
279
+ hidden_states = attn.to_out[1](hidden_states)
280
+
281
+ hidden_states = hidden_states / attn.rescale_output_factor
282
+ return hidden_states
283
+
284
+
285
+ class LEditsPPPipelineStableDiffusionXL(
286
+ DiffusionPipeline,
287
+ FromSingleFileMixin,
288
+ StableDiffusionXLLoraLoaderMixin,
289
+ TextualInversionLoaderMixin,
290
+ IPAdapterMixin,
291
+ ):
292
+ """
293
+ Pipeline for textual image editing using LEDits++ with Stable Diffusion XL.
294
+
295
+ This model inherits from [`DiffusionPipeline`] and builds on the [`StableDiffusionXLPipeline`]. Check the superclass
296
+ documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular
297
+ device, etc.).
298
+
299
+ In addition, the pipeline inherits the following loading methods:
300
+ - *LoRA*: [`LEditsPPPipelineStableDiffusionXL.load_lora_weights`]
301
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
302
+
303
+ as well as the following saving methods:
304
+ - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]
305
+
306
+ Args:
307
+ vae ([`AutoencoderKL`]):
308
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
309
+ text_encoder ([`~transformers.CLIPTextModel`]):
310
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
311
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
312
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
313
+ text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]):
314
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
315
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
316
+ specifically the
317
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
318
+ variant.
319
+ tokenizer ([`~transformers.CLIPTokenizer`]):
320
+ Tokenizer of class
321
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
322
+ tokenizer_2 ([`~transformers.CLIPTokenizer`]):
323
+ Second Tokenizer of class
324
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
325
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
326
+ scheduler ([`DPMSolverMultistepScheduler`] or [`DDIMScheduler`]):
327
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
328
+ [`DPMSolverMultistepScheduler`] or [`DDIMScheduler`]. If any other scheduler is passed, it will automatically
329
+ be set to [`DPMSolverMultistepScheduler`].
330
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):
331
+ Whether the negative prompt embeddings should always be forced to 0. Also see the config of
332
+ `stabilityai/stable-diffusion-xl-base-1.0`.
333
+ add_watermarker (`bool`, *optional*):
334
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
335
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
336
+ watermarker will be used.
337
+ """
338
+
339
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
340
+ _optional_components = [
341
+ "tokenizer",
342
+ "tokenizer_2",
343
+ "text_encoder",
344
+ "text_encoder_2",
345
+ "image_encoder",
346
+ "feature_extractor",
347
+ ]
348
+ _callback_tensor_inputs = [
349
+ "latents",
350
+ "prompt_embeds",
351
+ "negative_prompt_embeds",
352
+ "add_text_embeds",
353
+ "add_time_ids",
354
+ "negative_pooled_prompt_embeds",
355
+ "negative_add_time_ids",
356
+ ]
357
+
358
+ def __init__(
359
+ self,
360
+ vae: AutoencoderKL,
361
+ text_encoder: CLIPTextModel,
362
+ text_encoder_2: CLIPTextModelWithProjection,
363
+ tokenizer: CLIPTokenizer,
364
+ tokenizer_2: CLIPTokenizer,
365
+ unet: UNet2DConditionModel,
366
+ scheduler: Union[DPMSolverMultistepScheduler, DDIMScheduler],
367
+ image_encoder: CLIPVisionModelWithProjection = None,
368
+ feature_extractor: CLIPImageProcessor = None,
369
+ force_zeros_for_empty_prompt: bool = True,
370
+ add_watermarker: Optional[bool] = None,
371
+ ):
372
+ super().__init__()
373
+
374
+ self.register_modules(
375
+ vae=vae,
376
+ text_encoder=text_encoder,
377
+ text_encoder_2=text_encoder_2,
378
+ tokenizer=tokenizer,
379
+ tokenizer_2=tokenizer_2,
380
+ unet=unet,
381
+ scheduler=scheduler,
382
+ image_encoder=image_encoder,
383
+ feature_extractor=feature_extractor,
384
+ )
385
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
386
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
387
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
388
+
389
+ if not isinstance(scheduler, DDIMScheduler) and not isinstance(scheduler, DPMSolverMultistepScheduler):
390
+ self.scheduler = DPMSolverMultistepScheduler.from_config(
391
+ scheduler.config, algorithm_type="sde-dpmsolver++", solver_order=2
392
+ )
393
+ logger.warning(
394
+ "This pipeline only supports DDIMScheduler and DPMSolverMultistepScheduler. "
395
+ "The scheduler has been changed to DPMSolverMultistepScheduler."
396
+ )
397
+
398
+ self.default_sample_size = self.unet.config.sample_size
399
+
400
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
401
+
402
+ if add_watermarker:
403
+ self.watermark = StableDiffusionXLWatermarker()
404
+ else:
405
+ self.watermark = None
406
+ self.inversion_steps = None
407
+
408
+ def encode_prompt(
409
+ self,
410
+ device: Optional[torch.device] = None,
411
+ num_images_per_prompt: int = 1,
412
+ negative_prompt: Optional[str] = None,
413
+ negative_prompt_2: Optional[str] = None,
414
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
415
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
416
+ lora_scale: Optional[float] = None,
417
+ clip_skip: Optional[int] = None,
418
+ enable_edit_guidance: bool = True,
419
+ editing_prompt: Optional[str] = None,
420
+ editing_prompt_embeds: Optional[torch.FloatTensor] = None,
421
+ editing_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
422
+ ) -> object:
423
+ r"""
424
+ Encodes the prompt into text encoder hidden states.
425
+
426
+ Args:
427
+ device: (`torch.device`):
428
+ torch device
429
+ num_images_per_prompt (`int`):
430
+ number of images that should be generated per prompt
431
+ negative_prompt (`str` or `List[str]`, *optional*):
432
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
433
+ `negative_prompt_embeds` instead.
434
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
435
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
436
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
437
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
438
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
439
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
440
+ argument.
441
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
442
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
443
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
444
+ input argument.
445
+ lora_scale (`float`, *optional*):
446
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
447
+ clip_skip (`int`, *optional*):
448
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
449
+ the output of the pre-final layer will be used for computing the prompt embeddings.
450
+ enable_edit_guidance (`bool`):
451
+ Whether to guide towards an editing prompt or not.
452
+ editing_prompt (`str` or `List[str]`, *optional*):
453
+ Editing prompt(s) to be encoded. If not defined and 'enable_edit_guidance' is True, one has to pass
454
+ `editing_prompt_embeds` instead.
455
+ editing_prompt_embeds (`torch.FloatTensor`, *optional*):
456
+ Pre-generated edit text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
457
+ weighting. If not provided and 'enable_edit_guidance' is True, editing_prompt_embeds will be generated from `editing_prompt` input
458
+ argument.
459
+ editing_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
460
+ Pre-generated edit pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
461
+ weighting. If not provided, pooled editing_pooled_prompt_embeds will be generated from `editing_prompt`
462
+ input argument.
463
+ """
464
+ device = device or self._execution_device
465
+
466
+ # set lora scale so that monkey patched LoRA
467
+ # function of text encoder can correctly access it
468
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
469
+ self._lora_scale = lora_scale
470
+
471
+ # dynamically adjust the LoRA scale
472
+ if self.text_encoder is not None:
473
+ if not USE_PEFT_BACKEND:
474
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
475
+ else:
476
+ scale_lora_layers(self.text_encoder, lora_scale)
477
+
478
+ if self.text_encoder_2 is not None:
479
+ if not USE_PEFT_BACKEND:
480
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
481
+ else:
482
+ scale_lora_layers(self.text_encoder_2, lora_scale)
483
+
484
+ batch_size = self.batch_size
485
+
486
+ # Define tokenizers and text encoders
487
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
488
+ text_encoders = (
489
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
490
+ )
491
+ num_edit_tokens = 0
492
+
493
+ # get unconditional embeddings for classifier free guidance
494
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
495
+
496
+ if negative_prompt_embeds is None:
497
+ negative_prompt = negative_prompt or ""
498
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
499
+
500
+ # normalize str to list
501
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
502
+ negative_prompt_2 = (
503
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
504
+ )
505
+
506
+ uncond_tokens: List[str]
507
+
508
+ if batch_size != len(negative_prompt):
509
+ raise ValueError(
510
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but image inversion "
511
+ f" has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
512
+ " the batch size of the input images."
513
+ )
514
+ else:
515
+ uncond_tokens = [negative_prompt, negative_prompt_2]
516
+
517
+ negative_prompt_embeds_list = []
518
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
519
+ if isinstance(self, TextualInversionLoaderMixin):
520
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
521
+
522
+ uncond_input = tokenizer(
523
+ negative_prompt,
524
+ padding="max_length",
525
+ max_length=tokenizer.model_max_length,
526
+ truncation=True,
527
+ return_tensors="pt",
528
+ )
529
+
530
+ negative_prompt_embeds = text_encoder(
531
+ uncond_input.input_ids.to(device),
532
+ output_hidden_states=True,
533
+ )
534
+ # We are always interested only in the pooled output of the final text encoder
535
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
536
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
537
+
538
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
539
+
540
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
541
+
542
+ if zero_out_negative_prompt:
543
+ negative_prompt_embeds = torch.zeros_like(negative_prompt_embeds)
544
+ negative_pooled_prompt_embeds = torch.zeros_like(negative_pooled_prompt_embeds)
545
+
546
+ if enable_edit_guidance and editing_prompt_embeds is None:
547
+ editing_prompt_2 = editing_prompt
548
+
549
+ editing_prompts = [editing_prompt, editing_prompt_2]
550
+ edit_prompt_embeds_list = []
551
+
552
+ for editing_prompt, tokenizer, text_encoder in zip(editing_prompts, tokenizers, text_encoders):
553
+ if isinstance(self, TextualInversionLoaderMixin):
554
+ editing_prompt = self.maybe_convert_prompt(editing_prompt, tokenizer)
555
+
556
+ max_length = negative_prompt_embeds.shape[1]
557
+ edit_concepts_input = tokenizer(
558
+ # [x for item in editing_prompt for x in repeat(item, batch_size)],
559
+ editing_prompt,
560
+ padding="max_length",
561
+ max_length=max_length,
562
+ truncation=True,
563
+ return_tensors="pt",
564
+ return_length=True,
565
+ )
566
+ num_edit_tokens = edit_concepts_input.length - 2
567
+
568
+ edit_concepts_embeds = text_encoder(
569
+ edit_concepts_input.input_ids.to(device),
570
+ output_hidden_states=True,
571
+ )
572
+ # We are always interested only in the pooled output of the final text encoder
573
+ editing_pooled_prompt_embeds = edit_concepts_embeds[0]
574
+ if clip_skip is None:
575
+ edit_concepts_embeds = edit_concepts_embeds.hidden_states[-2]
576
+ else:
577
+ # "2" because SDXL always indexes from the penultimate layer.
578
+ edit_concepts_embeds = edit_concepts_embeds.hidden_states[-(clip_skip + 2)]
579
+
580
+ edit_prompt_embeds_list.append(edit_concepts_embeds)
581
+
582
+ edit_concepts_embeds = torch.concat(edit_prompt_embeds_list, dim=-1)
583
+ elif not enable_edit_guidance:
584
+ edit_concepts_embeds = None
585
+ editing_pooled_prompt_embeds = None
586
+
587
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
588
+ bs_embed, seq_len, _ = negative_prompt_embeds.shape
589
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
590
+ seq_len = negative_prompt_embeds.shape[1]
591
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
592
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
593
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
594
+
595
+ if enable_edit_guidance:
596
+ bs_embed_edit, seq_len, _ = edit_concepts_embeds.shape
597
+ edit_concepts_embeds = edit_concepts_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
598
+ edit_concepts_embeds = edit_concepts_embeds.repeat(1, num_images_per_prompt, 1)
599
+ edit_concepts_embeds = edit_concepts_embeds.view(bs_embed_edit * num_images_per_prompt, seq_len, -1)
600
+
601
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
602
+ bs_embed * num_images_per_prompt, -1
603
+ )
604
+
605
+ if enable_edit_guidance:
606
+ editing_pooled_prompt_embeds = editing_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
607
+ bs_embed_edit * num_images_per_prompt, -1
608
+ )
609
+
610
+ if self.text_encoder is not None:
611
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
612
+ # Retrieve the original scale by scaling back the LoRA layers
613
+ unscale_lora_layers(self.text_encoder, lora_scale)
614
+
615
+ if self.text_encoder_2 is not None:
616
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
617
+ # Retrieve the original scale by scaling back the LoRA layers
618
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
619
+
620
+ return (
621
+ negative_prompt_embeds,
622
+ edit_concepts_embeds,
623
+ negative_pooled_prompt_embeds,
624
+ editing_pooled_prompt_embeds,
625
+ num_edit_tokens,
626
+ )
627
+
628
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
629
+ def prepare_extra_step_kwargs(self, eta, generator=None):
630
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
631
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
632
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
633
+ # and should be between [0, 1]
634
+
635
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
636
+ extra_step_kwargs = {}
637
+ if accepts_eta:
638
+ extra_step_kwargs["eta"] = eta
639
+
640
+ # check if the scheduler accepts generator
641
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
642
+ if accepts_generator:
643
+ extra_step_kwargs["generator"] = generator
644
+ return extra_step_kwargs
645
+
646
+ def check_inputs(
647
+ self,
648
+ negative_prompt=None,
649
+ negative_prompt_2=None,
650
+ negative_prompt_embeds=None,
651
+ negative_pooled_prompt_embeds=None,
652
+ ):
653
+ if negative_prompt is not None and negative_prompt_embeds is not None:
654
+ raise ValueError(
655
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
656
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
657
+ )
658
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
659
+ raise ValueError(
660
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
661
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
662
+ )
663
+
664
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
665
+ raise ValueError(
666
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
667
+ )
668
+
669
+ # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
670
+ def prepare_latents(self, device, latents):
671
+ latents = latents.to(device)
672
+
673
+ # scale the initial noise by the standard deviation required by the scheduler
674
+ latents = latents * self.scheduler.init_noise_sigma
675
+ return latents
676
+
677
+ def _get_add_time_ids(
678
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
679
+ ):
680
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
681
+
682
+ passed_add_embed_dim = (
683
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
684
+ )
685
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
686
+
687
+ if expected_add_embed_dim != passed_add_embed_dim:
688
+ raise ValueError(
689
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
690
+ )
691
+
692
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
693
+ return add_time_ids
694
+
695
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
696
+ def upcast_vae(self):
697
+ dtype = self.vae.dtype
698
+ self.vae.to(dtype=torch.float32)
699
+ use_torch_2_0_or_xformers = isinstance(
700
+ self.vae.decoder.mid_block.attentions[0].processor,
701
+ (
702
+ AttnProcessor2_0,
703
+ XFormersAttnProcessor,
704
+ LoRAXFormersAttnProcessor,
705
+ LoRAAttnProcessor2_0,
706
+ ),
707
+ )
708
+ # if xformers or torch_2_0 is used attention block does not need
709
+ # to be in float32 which can save lots of memory
710
+ if use_torch_2_0_or_xformers:
711
+ self.vae.post_quant_conv.to(dtype)
712
+ self.vae.decoder.conv_in.to(dtype)
713
+ self.vae.decoder.mid_block.to(dtype)
714
+
715
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
716
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
717
+ """
718
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
719
+
720
+ Args:
721
+ w (`torch.Tensor`):
722
+ guidance scale values at which to generate embedding vectors
723
+ embedding_dim (`int`, *optional*, defaults to 512):
724
+ dimension of the embeddings to generate
725
+ dtype:
726
+ data type of the generated embeddings
727
+
728
+ Returns:
729
+ `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
730
+ """
731
+ assert len(w.shape) == 1
732
+ w = w * 1000.0
733
+
734
+ half_dim = embedding_dim // 2
735
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
736
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
737
+ emb = w.to(dtype)[:, None] * emb[None, :]
738
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
739
+ if embedding_dim % 2 == 1: # zero pad
740
+ emb = torch.nn.functional.pad(emb, (0, 1))
741
+ assert emb.shape == (w.shape[0], embedding_dim)
742
+ return emb
743
+
744
+ @property
745
+ def guidance_scale(self):
746
+ return self._guidance_scale
747
+
748
+ @property
749
+ def guidance_rescale(self):
750
+ return self._guidance_rescale
751
+
752
+ @property
753
+ def clip_skip(self):
754
+ return self._clip_skip
755
+
756
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
757
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
758
+ # corresponds to doing no classifier free guidance.
759
+ @property
760
+ def do_classifier_free_guidance(self):
761
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
762
+
763
+ @property
764
+ def cross_attention_kwargs(self):
765
+ return self._cross_attention_kwargs
766
+
767
+ @property
768
+ def denoising_end(self):
769
+ return self._denoising_end
770
+
771
+ @property
772
+ def num_timesteps(self):
773
+ return self._num_timesteps
774
+
775
+ # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.prepare_unet
776
+ def prepare_unet(self, attention_store, PnP: bool = False):
777
+ attn_procs = {}
778
+ for name in self.unet.attn_processors.keys():
779
+ if name.startswith("mid_block"):
780
+ place_in_unet = "mid"
781
+ elif name.startswith("up_blocks"):
782
+ place_in_unet = "up"
783
+ elif name.startswith("down_blocks"):
784
+ place_in_unet = "down"
785
+ else:
786
+ continue
787
+
788
+ if "attn2" in name and place_in_unet != "mid":
789
+ attn_procs[name] = LEDITSCrossAttnProcessor(
790
+ attention_store=attention_store,
791
+ place_in_unet=place_in_unet,
792
+ pnp=PnP,
793
+ editing_prompts=self.enabled_editing_prompts,
794
+ )
795
+ else:
796
+ attn_procs[name] = AttnProcessor()
797
+
798
+ self.unet.set_attn_processor(attn_procs)
799
+
800
+ @torch.no_grad()
801
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
802
+ def __call__(
803
+ self,
804
+ denoising_end: Optional[float] = None,
805
+ negative_prompt: Optional[Union[str, List[str]]] = None,
806
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
807
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
808
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
809
+ ip_adapter_image: Optional[PipelineImageInput] = None,
810
+ output_type: Optional[str] = "pil",
811
+ return_dict: bool = True,
812
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
813
+ guidance_rescale: float = 0.0,
814
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
815
+ target_size: Optional[Tuple[int, int]] = None,
816
+ editing_prompt: Optional[Union[str, List[str]]] = None,
817
+ editing_prompt_embeddings: Optional[torch.Tensor] = None,
818
+ editing_pooled_prompt_embeds: Optional[torch.Tensor] = None,
819
+ reverse_editing_direction: Optional[Union[bool, List[bool]]] = False,
820
+ edit_guidance_scale: Optional[Union[float, List[float]]] = 5,
821
+ edit_warmup_steps: Optional[Union[int, List[int]]] = 0,
822
+ edit_cooldown_steps: Optional[Union[int, List[int]]] = None,
823
+ edit_threshold: Optional[Union[float, List[float]]] = 0.9,
824
+ sem_guidance: Optional[List[torch.Tensor]] = None,
825
+ use_cross_attn_mask: bool = False,
826
+ use_intersect_mask: bool = False,
827
+ user_mask: Optional[torch.FloatTensor] = None,
828
+ attn_store_steps: Optional[List[int]] = [],
829
+ store_averaged_over_steps: bool = True,
830
+ clip_skip: Optional[int] = None,
831
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
832
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
833
+ **kwargs,
834
+ ):
835
+ r"""
836
+ The call function to the pipeline for editing. The [`~pipelines.ledits_pp.LEditsPPPipelineStableDiffusionXL.invert`]
837
+ method has to be called beforehand. Edits will always be performed for the last inverted image(s).
838
+
839
+ Args:
840
+ denoising_end (`float`, *optional*):
841
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
842
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
843
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
844
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
845
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in the **Refining the Image Output** section of the Stable Diffusion XL documentation.
846
+ negative_prompt (`str` or `List[str]`, *optional*):
847
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
848
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
849
+ less than `1`).
850
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
851
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
852
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
853
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
854
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
855
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
856
+ argument.
857
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
858
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
859
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
860
+ input argument.
861
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
862
+ Optional image input to work with IP Adapters.
863
+ output_type (`str`, *optional*, defaults to `"pil"`):
864
+ The output format of the generated image. Choose between
865
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
866
+ return_dict (`bool`, *optional*, defaults to `True`):
867
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
868
+ of a plain tuple.
869
+ callback (`Callable`, *optional*):
870
+ A function that will be called every `callback_steps` steps during inference. The function will be
871
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
872
+ callback_steps (`int`, *optional*, defaults to 1):
873
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
874
+ called at every step.
875
+ cross_attention_kwargs (`dict`, *optional*):
876
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
877
+ `self.processor` in
878
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
879
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
880
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
881
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
882
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
883
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
884
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
885
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
886
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
887
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
888
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
889
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
890
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
891
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
892
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
893
+ editing_prompt (`str` or `List[str]`, *optional*):
894
+ The prompt or prompts to guide the image generation. The image is reconstructed by setting
895
+ `editing_prompt = None`. Guidance direction of prompt should be specified via `reverse_editing_direction`.
896
+ editing_prompt_embeddings (`torch.Tensor`, *optional*):
897
+ Pre-generated edit text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
898
+ weighting. If not provided, editing_prompt_embeddings will be generated from `editing_prompt` input
899
+ argument.
900
+ editing_pooled_prompt_embeds (`torch.Tensor`, *optional*):
901
+ Pre-generated pooled edit text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
902
+ weighting. If not provided, editing_pooled_prompt_embeds will be generated from `editing_prompt` input
903
+ argument.
904
+ reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`):
905
+ Whether the corresponding prompt in `editing_prompt` should be increased or decreased.
906
+ edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5):
907
+ Guidance scale for guiding the image generation. If provided as a list, values should correspond to `editing_prompt`.
908
+ `edit_guidance_scale` is defined as `s_e` of equation 12 of
909
+ [LEDITS++ Paper](https://arxiv.org/abs/2311.16711).
910
+ edit_warmup_steps (`int` or `List[int]`, *optional*, defaults to 0):
911
+ Number of diffusion steps (for each prompt) for which guidance is not applied.
912
+ edit_cooldown_steps (`int` or `List[int]`, *optional*, defaults to `None`):
913
+ Number of diffusion steps (for each prompt) after which guidance is no longer applied.
914
+ edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9):
915
+ Masking threshold of guidance. Threshold should be proportional to the image region that is modified.
916
+ `edit_threshold` is defined as `λ` of equation 12 of [LEDITS++ paper](https://arxiv.org/abs/2311.16711).
917
+ sem_guidance (`List[torch.Tensor]`, *optional*):
918
+ List of pre-generated guidance vectors to be applied at generation. Length of the list has to
919
+ correspond to `num_inference_steps`.
920
+ use_cross_attn_mask (`bool`, defaults to `False`):
921
+ Whether cross-attention masks are used. Cross-attention masks are always used when use_intersect_mask
922
+ is set to `True`. Cross-attention masks are defined as 'M^1' of equation 12 of
923
+ [LEDITS++ paper](https://arxiv.org/pdf/2311.16711.pdf).
924
+ use_intersect_mask (`bool`, defaults to `False`):
925
+ Whether the masking term is calculated as intersection of cross-attention masks and masks derived
926
+ from the noise estimate. Cross-attention mask are defined as 'M^1' and masks derived from the noise
927
+ estimate are defined as 'M^2' of equation 12 of [LEDITS++ paper](https://arxiv.org/pdf/2311.16711.pdf).
928
+ user_mask (`torch.FloatTensor`, *optional*):
929
+ User-provided mask for even better control over the editing process. This is helpful when LEDITS++'s implicit
930
+ masks do not meet user preferences.
931
+ attn_store_steps (`List[int]`, *optional*, defaults to `[]`):
932
+ Steps for which the attention maps are stored in the AttentionStore. Just for visualization purposes.
933
+ store_averaged_over_steps (`bool`, defaults to `True`):
934
+ Whether the attention maps for the 'attn_store_steps' are stored averaged over the diffusion steps.
935
+ If False, attention maps for each step are stored separately. Just for visualization purposes.
936
+ clip_skip (`int`, *optional*):
937
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
938
+ the output of the pre-final layer will be used for computing the prompt embeddings.
939
+ callback_on_step_end (`Callable`, *optional*):
940
+ A function that is called at the end of each denoising step during inference. The function is called
941
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
942
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
943
+ `callback_on_step_end_tensor_inputs`.
944
+ callback_on_step_end_tensor_inputs (`List[str]`, *optional*, defaults to `["latents"]`):
945
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
946
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
947
+ `._callback_tensor_inputs` attribute of your pipeline class.
948
+
949
+ Examples:
950
+
951
+ Returns:
952
+ [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] or `tuple`:
953
+ [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] if `return_dict` is True,
954
+ otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
955
+ """
956
+ if self.inversion_steps is None:
957
+ raise ValueError(
958
+ "You need to invert an input image first before calling the pipeline. The `invert` method has to be called beforehand. Edits will always be performed for the last inverted image(s)."
959
+ )
960
+
961
+ eta = self.eta
962
+ num_images_per_prompt = 1
963
+ latents = self.init_latents
964
+
965
+ zs = self.zs
966
+ self.scheduler.set_timesteps(len(self.scheduler.timesteps))
967
+
968
+ if use_intersect_mask:
969
+ use_cross_attn_mask = True
970
+
971
+ if use_cross_attn_mask:
972
+ self.smoothing = LeditsGaussianSmoothing(self.device)
973
+
974
+ if user_mask is not None:
975
+ user_mask = user_mask.to(self.device)
976
+
977
+ # TODO: Check inputs
978
+ # 1. Check inputs. Raise error if not correct
979
+ # self.check_inputs(
980
+ # callback_steps,
981
+ # negative_prompt,
982
+ # negative_prompt_2,
983
+ # prompt_embeds,
984
+ # negative_prompt_embeds,
985
+ # pooled_prompt_embeds,
986
+ # negative_pooled_prompt_embeds,
987
+ # )
988
+ self._guidance_rescale = guidance_rescale
989
+ self._clip_skip = clip_skip
990
+ self._cross_attention_kwargs = cross_attention_kwargs
991
+ self._denoising_end = denoising_end
992
+
993
+ # 2. Define call parameters
994
+ batch_size = self.batch_size
995
+
996
+ device = self._execution_device
997
+
998
+ if editing_prompt:
999
+ enable_edit_guidance = True
1000
+ if isinstance(editing_prompt, str):
1001
+ editing_prompt = [editing_prompt]
1002
+ self.enabled_editing_prompts = len(editing_prompt)
1003
+ elif editing_prompt_embeddings is not None:
1004
+ enable_edit_guidance = True
1005
+ self.enabled_editing_prompts = editing_prompt_embeddings.shape[0]
1006
+ else:
1007
+ self.enabled_editing_prompts = 0
1008
+ enable_edit_guidance = False
1009
+
1010
+ # 3. Encode input prompt
1011
+ text_encoder_lora_scale = (
1012
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
1013
+ )
1014
+ (
1015
+ prompt_embeds,
1016
+ edit_prompt_embeds,
1017
+ negative_pooled_prompt_embeds,
1018
+ pooled_edit_embeds,
1019
+ num_edit_tokens,
1020
+ ) = self.encode_prompt(
1021
+ device=device,
1022
+ num_images_per_prompt=num_images_per_prompt,
1023
+ negative_prompt=negative_prompt,
1024
+ negative_prompt_2=negative_prompt_2,
1025
+ negative_prompt_embeds=negative_prompt_embeds,
1026
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1027
+ lora_scale=text_encoder_lora_scale,
1028
+ clip_skip=self.clip_skip,
1029
+ enable_edit_guidance=enable_edit_guidance,
1030
+ editing_prompt=editing_prompt,
1031
+ editing_prompt_embeds=editing_prompt_embeddings,
1032
+ editing_pooled_prompt_embeds=editing_pooled_prompt_embeds,
1033
+ )
1034
+
1035
+ # 4. Prepare timesteps
1036
+ # self.scheduler.set_timesteps(num_inference_steps, device=device)
1037
+
1038
+ timesteps = self.inversion_steps
1039
+ t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
1040
+
1041
+ if use_cross_attn_mask:
1042
+ self.attention_store = LeditsAttentionStore(
1043
+ average=store_averaged_over_steps,
1044
+ batch_size=batch_size,
1045
+ max_size=(latents.shape[-2] / 4.0) * (latents.shape[-1] / 4.0),
1046
+ max_resolution=None,
1047
+ )
1048
+ self.prepare_unet(self.attention_store)
1049
+ resolution = latents.shape[-2:]
1050
+ att_res = (int(resolution[0] / 4), int(resolution[1] / 4))
1051
+
1052
+ # 5. Prepare latent variables
1053
+ latents = self.prepare_latents(device=device, latents=latents)
1054
+
1055
+ # 6. Prepare extra step kwargs.
1056
+ extra_step_kwargs = self.prepare_extra_step_kwargs(eta)
1057
+
1058
+ if self.text_encoder_2 is None:
1059
+ text_encoder_projection_dim = int(negative_pooled_prompt_embeds.shape[-1])
1060
+ else:
1061
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1062
+
1063
+ # 7. Prepare added time ids & embeddings
1064
+ add_text_embeds = negative_pooled_prompt_embeds
1065
+ add_time_ids = self._get_add_time_ids(
1066
+ self.size,
1067
+ crops_coords_top_left,
1068
+ self.size,
1069
+ dtype=negative_pooled_prompt_embeds.dtype,
1070
+ text_encoder_projection_dim=text_encoder_projection_dim,
1071
+ )
1072
+
1073
+ if enable_edit_guidance:
1074
+ prompt_embeds = torch.cat([prompt_embeds, edit_prompt_embeds], dim=0)
1075
+ add_text_embeds = torch.cat([add_text_embeds, pooled_edit_embeds], dim=0)
1076
+ edit_concepts_time_ids = add_time_ids.repeat(edit_prompt_embeds.shape[0], 1)
1077
+ add_time_ids = torch.cat([add_time_ids, edit_concepts_time_ids], dim=0)
1078
+ self.text_cross_attention_maps = [editing_prompt] if isinstance(editing_prompt, str) else editing_prompt
1079
+
1080
+ prompt_embeds = prompt_embeds.to(device)
1081
+ add_text_embeds = add_text_embeds.to(device)
1082
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
1083
+
1084
+ if ip_adapter_image is not None:
1085
+ # TODO: fix image encoding
1086
+ image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
1087
+ if self.do_classifier_free_guidance:
1088
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
1089
+ image_embeds = image_embeds.to(device)
1090
+
1091
+ # 8. Denoising loop
1092
+ self.sem_guidance = None
1093
+ self.activation_mask = None
1094
+
1095
+ if (
1096
+ self.denoising_end is not None
1097
+ and isinstance(self.denoising_end, float)
1098
+ and self.denoising_end > 0
1099
+ and self.denoising_end < 1
1100
+ ):
1101
+ discrete_timestep_cutoff = int(
1102
+ round(
1103
+ self.scheduler.config.num_train_timesteps
1104
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
1105
+ )
1106
+ )
1107
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1108
+ timesteps = timesteps[:num_inference_steps]
1109
+
1110
+ # 9. Optionally get Guidance Scale Embedding
1111
+ timestep_cond = None
1112
+ if self.unet.config.time_cond_proj_dim is not None:
1113
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1114
+ timestep_cond = self.get_guidance_scale_embedding(
1115
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1116
+ ).to(device=device, dtype=latents.dtype)
1117
+
1118
+ self._num_timesteps = len(timesteps)
1119
+ with self.progress_bar(total=self._num_timesteps) as progress_bar:
1120
+ for i, t in enumerate(timesteps):
1121
+ # expand the latents if we are doing classifier free guidance
1122
+ latent_model_input = torch.cat([latents] * (1 + self.enabled_editing_prompts))
1123
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1124
+ # predict the noise residual
1125
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1126
+ if ip_adapter_image is not None:
1127
+ added_cond_kwargs["image_embeds"] = image_embeds
1128
+ noise_pred = self.unet(
1129
+ latent_model_input,
1130
+ t,
1131
+ encoder_hidden_states=prompt_embeds,
1132
+ cross_attention_kwargs=cross_attention_kwargs,
1133
+ added_cond_kwargs=added_cond_kwargs,
1134
+ return_dict=False,
1135
+ )[0]
1136
+
1137
+ noise_pred_out = noise_pred.chunk(1 + self.enabled_editing_prompts) # [b,4, 64, 64]
1138
+ noise_pred_uncond = noise_pred_out[0]
1139
+ noise_pred_edit_concepts = noise_pred_out[1:]
1140
+
1141
+ noise_guidance_edit = torch.zeros(
1142
+ noise_pred_uncond.shape,
1143
+ device=self.device,
1144
+ dtype=noise_pred_uncond.dtype,
1145
+ )
1146
+
1147
+ if sem_guidance is not None and len(sem_guidance) > i:
1148
+ noise_guidance_edit += sem_guidance[i].to(self.device)
1149
+
1150
+ elif enable_edit_guidance:
1151
+ if self.activation_mask is None:
1152
+ self.activation_mask = torch.zeros(
1153
+ (len(timesteps), self.enabled_editing_prompts, *noise_pred_edit_concepts[0].shape)
1154
+ )
1155
+ if self.sem_guidance is None:
1156
+ self.sem_guidance = torch.zeros((len(timesteps), *noise_pred_uncond.shape))
1157
+
1158
+ # noise_guidance_edit = torch.zeros_like(noise_guidance)
1159
+ for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts):
1160
+ if isinstance(edit_warmup_steps, list):
1161
+ edit_warmup_steps_c = edit_warmup_steps[c]
1162
+ else:
1163
+ edit_warmup_steps_c = edit_warmup_steps
1164
+ if i < edit_warmup_steps_c:
1165
+ continue
1166
+
1167
+ if isinstance(edit_guidance_scale, list):
1168
+ edit_guidance_scale_c = edit_guidance_scale[c]
1169
+ else:
1170
+ edit_guidance_scale_c = edit_guidance_scale
1171
+
1172
+ if isinstance(edit_threshold, list):
1173
+ edit_threshold_c = edit_threshold[c]
1174
+ else:
1175
+ edit_threshold_c = edit_threshold
1176
+ if isinstance(reverse_editing_direction, list):
1177
+ reverse_editing_direction_c = reverse_editing_direction[c]
1178
+ else:
1179
+ reverse_editing_direction_c = reverse_editing_direction
1180
+
1181
+ if isinstance(edit_cooldown_steps, list):
1182
+ edit_cooldown_steps_c = edit_cooldown_steps[c]
1183
+ elif edit_cooldown_steps is None:
1184
+ edit_cooldown_steps_c = i + 1
1185
+ else:
1186
+ edit_cooldown_steps_c = edit_cooldown_steps
1187
+
1188
+ if i >= edit_cooldown_steps_c:
1189
+ continue
1190
+
1191
+ noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond
1192
+
1193
+ if reverse_editing_direction_c:
1194
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1
1195
+
1196
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c
1197
+
1198
+ if user_mask is not None:
1199
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * user_mask
1200
+
1201
+ if use_cross_attn_mask:
1202
+ out = self.attention_store.aggregate_attention(
1203
+ attention_maps=self.attention_store.step_store,
1204
+ prompts=self.text_cross_attention_maps,
1205
+ res=att_res,
1206
+ from_where=["up", "down"],
1207
+ is_cross=True,
1208
+ select=self.text_cross_attention_maps.index(editing_prompt[c]),
1209
+ )
1210
+ attn_map = out[:, :, :, 1 : 1 + num_edit_tokens[c]] # 0 -> startoftext
1211
+
1212
+ # average over all tokens
1213
+ if attn_map.shape[3] != num_edit_tokens[c]:
1214
+ raise ValueError(
1215
+ f"Incorrect shape of attention_map. Expected size {num_edit_tokens[c]}, but found {attn_map.shape[3]}!"
1216
+ )
1217
+ attn_map = torch.sum(attn_map, dim=3)
1218
+
1219
+ # gaussian_smoothing
1220
+ attn_map = F.pad(attn_map.unsqueeze(1), (1, 1, 1, 1), mode="reflect")
1221
+ attn_map = self.smoothing(attn_map).squeeze(1)
1222
+
1223
+ # torch.quantile function expects float32
1224
+ if attn_map.dtype == torch.float32:
1225
+ tmp = torch.quantile(attn_map.flatten(start_dim=1), edit_threshold_c, dim=1)
1226
+ else:
1227
+ tmp = torch.quantile(
1228
+ attn_map.flatten(start_dim=1).to(torch.float32), edit_threshold_c, dim=1
1229
+ ).to(attn_map.dtype)
1230
+ attn_mask = torch.where(
1231
+ attn_map >= tmp.unsqueeze(1).unsqueeze(1).repeat(1, *att_res), 1.0, 0.0
1232
+ )
1233
+
1234
+ # resolution must match latent space dimension
1235
+ attn_mask = F.interpolate(
1236
+ attn_mask.unsqueeze(1),
1237
+ noise_guidance_edit_tmp.shape[-2:], # 64,64
1238
+ ).repeat(1, 4, 1, 1)
1239
+ self.activation_mask[i, c] = attn_mask.detach().cpu()
1240
+ if not use_intersect_mask:
1241
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * attn_mask
1242
+
1243
+ if use_intersect_mask:
1244
+ noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp)
1245
+ noise_guidance_edit_tmp_quantile = torch.sum(
1246
+ noise_guidance_edit_tmp_quantile, dim=1, keepdim=True
1247
+ )
1248
+ noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(
1249
+ 1, self.unet.config.in_channels, 1, 1
1250
+ )
1251
+
1252
+ # torch.quantile function expects float32
1253
+ if noise_guidance_edit_tmp_quantile.dtype == torch.float32:
1254
+ tmp = torch.quantile(
1255
+ noise_guidance_edit_tmp_quantile.flatten(start_dim=2),
1256
+ edit_threshold_c,
1257
+ dim=2,
1258
+ keepdim=False,
1259
+ )
1260
+ else:
1261
+ tmp = torch.quantile(
1262
+ noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32),
1263
+ edit_threshold_c,
1264
+ dim=2,
1265
+ keepdim=False,
1266
+ ).to(noise_guidance_edit_tmp_quantile.dtype)
1267
+
1268
+ intersect_mask = (
1269
+ torch.where(
1270
+ noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
1271
+ torch.ones_like(noise_guidance_edit_tmp),
1272
+ torch.zeros_like(noise_guidance_edit_tmp),
1273
+ )
1274
+ * attn_mask
1275
+ )
1276
+
1277
+ self.activation_mask[i, c] = intersect_mask.detach().cpu()
1278
+
1279
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * intersect_mask
1280
+
1281
+ elif not use_cross_attn_mask:
1282
+ # calculate quantile
1283
+ noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp)
1284
+ noise_guidance_edit_tmp_quantile = torch.sum(
1285
+ noise_guidance_edit_tmp_quantile, dim=1, keepdim=True
1286
+ )
1287
+ noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1, 4, 1, 1)
1288
+
1289
+ # torch.quantile function expects float32
1290
+ if noise_guidance_edit_tmp_quantile.dtype == torch.float32:
1291
+ tmp = torch.quantile(
1292
+ noise_guidance_edit_tmp_quantile.flatten(start_dim=2),
1293
+ edit_threshold_c,
1294
+ dim=2,
1295
+ keepdim=False,
1296
+ )
1297
+ else:
1298
+ tmp = torch.quantile(
1299
+ noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32),
1300
+ edit_threshold_c,
1301
+ dim=2,
1302
+ keepdim=False,
1303
+ ).to(noise_guidance_edit_tmp_quantile.dtype)
1304
+
1305
+ self.activation_mask[i, c] = (
1306
+ torch.where(
1307
+ noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
1308
+ torch.ones_like(noise_guidance_edit_tmp),
1309
+ torch.zeros_like(noise_guidance_edit_tmp),
1310
+ )
1311
+ .detach()
1312
+ .cpu()
1313
+ )
1314
+
1315
+ noise_guidance_edit_tmp = torch.where(
1316
+ noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
1317
+ noise_guidance_edit_tmp,
1318
+ torch.zeros_like(noise_guidance_edit_tmp),
1319
+ )
1320
+
1321
+ noise_guidance_edit += noise_guidance_edit_tmp
1322
+
1323
+ self.sem_guidance[i] = noise_guidance_edit.detach().cpu()
1324
+
1325
+ noise_pred = noise_pred_uncond + noise_guidance_edit
1326
+
1327
+ # compute the previous noisy sample x_t -> x_t-1
1328
+ if enable_edit_guidance and self.guidance_rescale > 0.0:
1329
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1330
+ noise_pred = rescale_noise_cfg(
1331
+ noise_pred,
1332
+ noise_pred_edit_concepts.mean(dim=0, keepdim=False),
1333
+ guidance_rescale=self.guidance_rescale,
1334
+ )
1335
+
1336
+ idx = t_to_idx[int(t)]
1337
+ latents = self.scheduler.step(
1338
+ noise_pred, t, latents, variance_noise=zs[idx], **extra_step_kwargs, return_dict=False
1339
+ )[0]
1340
+
1341
+ # step callback
1342
+ if use_cross_attn_mask:
1343
+ store_step = i in attn_store_steps
1344
+ self.attention_store.between_steps(store_step)
1345
+
1346
+ if callback_on_step_end is not None:
1347
+ callback_kwargs = {}
1348
+ for k in callback_on_step_end_tensor_inputs:
1349
+ callback_kwargs[k] = locals()[k]
1350
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1351
+
1352
+ latents = callback_outputs.pop("latents", latents)
1353
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1354
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1355
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1356
+ negative_pooled_prompt_embeds = callback_outputs.pop(
1357
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1358
+ )
1359
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1360
+ # negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
1361
+
1362
+ # call the callback, if provided
1363
+ if i == len(timesteps) - 1 or ((i + 1) > 0 and (i + 1) % self.scheduler.order == 0):
1364
+ progress_bar.update()
1365
+
1366
+ if XLA_AVAILABLE:
1367
+ xm.mark_step()
1368
+
1369
+ if not output_type == "latent":
1370
+ # make sure the VAE is in float32 mode, as it overflows in float16
1371
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1372
+
1373
+ if needs_upcasting:
1374
+ self.upcast_vae()
1375
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1376
+
1377
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1378
+
1379
+ # cast back to fp16 if needed
1380
+ if needs_upcasting:
1381
+ self.vae.to(dtype=torch.float16)
1382
+ else:
1383
+ image = latents
1384
+
1385
+ if not output_type == "latent":
1386
+ # apply watermark if available
1387
+ if self.watermark is not None:
1388
+ image = self.watermark.apply_watermark(image)
1389
+
1390
+ image = self.image_processor.postprocess(image, output_type=output_type)
1391
+
1392
+ # Offload all models
1393
+ self.maybe_free_model_hooks()
1394
+
1395
+ if not return_dict:
1396
+ return (image,)
1397
+
1398
+ return LEditsPPDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
1399
+
1400
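A minimal end-to-end usage sketch (editorial, not part of the diff): it assumes the public `LEditsPPPipelineStableDiffusionXL` class exported by this release, the SDXL base checkpoint as a placeholder model id, and a placeholder input-image URL. As the docstring above notes, `invert` (defined further below) has to run before `__call__`.

```python
# Hedged usage sketch for LEditsPPPipelineStableDiffusionXL; checkpoint id, image URL
# and edit settings are illustrative placeholders, not taken from the diff.
import torch
from diffusers import LEditsPPPipelineStableDiffusionXL
from diffusers.utils import load_image

pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

image = load_image("https://example.com/portrait.png")  # placeholder URL

# Inversion must run first; __call__ always edits the last inverted image(s).
pipe.invert(image=image, num_inversion_steps=50, skip=0.2)

edited = pipe(
    editing_prompt=["sunglasses"],
    reverse_editing_direction=[False],
    edit_guidance_scale=[7.0],
    edit_threshold=[0.9],
).images[0]
```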
+ @torch.no_grad()
1401
+ # Modified from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.encode_image
1402
+ def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="default", crops_coords=None):
1403
+ image = self.image_processor.preprocess(
1404
+ image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
1405
+ )
1406
+ resized = self.image_processor.postprocess(image=image, output_type="pil")
1407
+
1408
+ if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
1409
+ logger.warning(
1410
+ "Your input images far exceed the default resolution of the underlying diffusion model. "
1411
+ "The output images may contain severe artifacts! "
1412
+ "Consider down-sampling the input using the `height` and `width` parameters"
1413
+ )
1414
+ image = image.to(self.device, dtype=dtype)
1415
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1416
+
1417
+ if needs_upcasting:
1418
+ image = image.float()
1419
+ self.upcast_vae()
1420
+ image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1421
+
1422
+ x0 = self.vae.encode(image).latent_dist.mode()
1423
+ x0 = x0.to(dtype)
1424
+ # cast back to fp16 if needed
1425
+ if needs_upcasting:
1426
+ self.vae.to(dtype=torch.float16)
1427
+
1428
+ x0 = self.vae.config.scaling_factor * x0
1429
+ return x0, resized
1430
+
1431
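A small sketch of what `encode_image` above returns, assuming a `pipe` and PIL input `pil_image` as in the usage sketch earlier: VAE-mode latents already multiplied by `vae.config.scaling_factor`, plus the preprocessed input converted back to PIL for the inversion output.

```python
# Sketch only: `pipe` and `pil_image` are assumed to exist (see the usage sketch above).
x0, resized = pipe.encode_image(pil_image, dtype=pipe.text_encoder_2.dtype)

# x0: latents of shape (batch, 4, H // 8, W // 8), scaled by vae.config.scaling_factor
# resized: list of PIL images, later returned via LEditsPPInversionPipelineOutput
print(x0.shape, len(resized))
```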
+ @torch.no_grad()
1432
+ def invert(
1433
+ self,
1434
+ image: PipelineImageInput,
1435
+ source_prompt: str = "",
1436
+ source_guidance_scale=3.5,
1437
+ negative_prompt: str = None,
1438
+ negative_prompt_2: str = None,
1439
+ num_inversion_steps: int = 50,
1440
+ skip: float = 0.15,
1441
+ generator: Optional[torch.Generator] = None,
1442
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
1443
+ num_zero_noise_steps: int = 3,
1444
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1445
+ ):
1446
+ r"""
1447
+ The function of the pipeline for image inversion, as described by the [LEDITS++ Paper](https://arxiv.org/abs/2311.16711).
1448
+ If the scheduler is set to [`~schedulers.DDIMScheduler`], the inversion proposed by [edit-friendly DDPM](https://arxiv.org/abs/2304.06140)
1449
+ will be performed instead.
1450
+
1451
+ Args:
1452
+ image (`PipelineImageInput`):
1453
+ Input for the image(s) that are to be edited. Multiple input images have to share the same aspect
1454
+ ratio.
1455
+ source_prompt (`str`, defaults to `""`):
1456
+ Prompt describing the input image that will be used for guidance during inversion. Guidance is disabled
1457
+ if the `source_prompt` is `""`.
1458
+ source_guidance_scale (`float`, defaults to `3.5`):
1459
+ Strength of guidance during inversion.
1460
+ negative_prompt (`str` or `List[str]`, *optional*):
1461
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1462
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1463
+ less than `1`).
1464
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
1465
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
1466
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
1467
+ num_inversion_steps (`int`, defaults to `50`):
1468
+ Number of total performed inversion steps after discarding the initial `skip` steps.
1469
+ skip (`float`, defaults to `0.15`):
1470
+ Portion of initial steps that will be ignored for inversion and subsequent generation. Lower values
1471
+ will lead to stronger changes to the input image. `skip` has to be between `0` and `1`.
1472
+ generator (`torch.Generator`, *optional*):
1473
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
1474
+ inversion deterministic.
1475
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1476
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
1477
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
1478
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
1479
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1480
+ num_zero_noise_steps (`int`, defaults to `3`):
1481
+ Number of final diffusion steps that will not renoise the current image. If no steps are set to zero,
1482
+ SD-XL in combination with [`DPMSolverMultistepScheduler`] will produce noise artifacts.
1483
+ cross_attention_kwargs (`dict`, *optional*):
1484
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1485
+ `self.processor` in
1486
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1487
+
1488
+ Returns:
1489
+ [`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]:
1490
+ Output will contain the resized input image(s) and respective VAE reconstruction(s).
1491
+ """
1492
+
1493
+ # Reset attn processor, we do not want to store attn maps during inversion
1494
+ self.unet.set_attn_processor(AttnProcessor())
1495
+
1496
+ self.eta = 1.0
1497
+
1498
+ self.scheduler.config.timestep_spacing = "leading"
1499
+ self.scheduler.set_timesteps(int(num_inversion_steps * (1 + skip)))
1500
+ self.inversion_steps = self.scheduler.timesteps[-num_inversion_steps:]
1501
+ timesteps = self.inversion_steps
1502
+
1503
+ num_images_per_prompt = 1
1504
+
1505
+ device = self._execution_device
1506
+
1507
+ # 0. Ensure that only uncond embedding is used if prompt = ""
1508
+ if source_prompt == "":
1509
+ # noise pred should only be noise_pred_uncond
1510
+ source_guidance_scale = 0.0
1511
+ do_classifier_free_guidance = False
1512
+ else:
1513
+ do_classifier_free_guidance = source_guidance_scale > 1.0
1514
+
1515
+ # 1. prepare image
1516
+ x0, resized = self.encode_image(image, dtype=self.text_encoder_2.dtype)
1517
+ width = x0.shape[2] * self.vae_scale_factor
1518
+ height = x0.shape[3] * self.vae_scale_factor
1519
+ self.size = (height, width)
1520
+
1521
+ self.batch_size = x0.shape[0]
1522
+
1523
+ # 2. get embeddings
1524
+ text_encoder_lora_scale = (
1525
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
1526
+ )
1527
+
1528
+ if isinstance(source_prompt, str):
1529
+ source_prompt = [source_prompt] * self.batch_size
1530
+
1531
+ (
1532
+ negative_prompt_embeds,
1533
+ prompt_embeds,
1534
+ negative_pooled_prompt_embeds,
1535
+ edit_pooled_prompt_embeds,
1536
+ _,
1537
+ ) = self.encode_prompt(
1538
+ device=device,
1539
+ num_images_per_prompt=num_images_per_prompt,
1540
+ negative_prompt=negative_prompt,
1541
+ negative_prompt_2=negative_prompt_2,
1542
+ editing_prompt=source_prompt,
1543
+ lora_scale=text_encoder_lora_scale,
1544
+ enable_edit_guidance=do_classifier_free_guidance,
1545
+ )
1546
+ if self.text_encoder_2 is None:
1547
+ text_encoder_projection_dim = int(negative_pooled_prompt_embeds.shape[-1])
1548
+ else:
1549
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1550
+
1551
+ # 3. Prepare added time ids & embeddings
1552
+ add_text_embeds = negative_pooled_prompt_embeds
1553
+ add_time_ids = self._get_add_time_ids(
1554
+ self.size,
1555
+ crops_coords_top_left,
1556
+ self.size,
1557
+ dtype=negative_prompt_embeds.dtype,
1558
+ text_encoder_projection_dim=text_encoder_projection_dim,
1559
+ )
1560
+
1561
+ if do_classifier_free_guidance:
1562
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1563
+ add_text_embeds = torch.cat([add_text_embeds, edit_pooled_prompt_embeds], dim=0)
1564
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
1565
+
1566
+ negative_prompt_embeds = negative_prompt_embeds.to(device)
1567
+
1568
+ add_text_embeds = add_text_embeds.to(device)
1569
+ add_time_ids = add_time_ids.to(device).repeat(self.batch_size * num_images_per_prompt, 1)
1570
+
1571
+ # autoencoder reconstruction
1572
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
1573
+ self.upcast_vae()
1574
+ x0_tmp = x0.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1575
+ image_rec = self.vae.decode(
1576
+ x0_tmp / self.vae.config.scaling_factor, return_dict=False, generator=generator
1577
+ )[0]
1578
+ elif self.vae.config.force_upcast:
1579
+ x0_tmp = x0.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1580
+ image_rec = self.vae.decode(
1581
+ x0_tmp / self.vae.config.scaling_factor, return_dict=False, generator=generator
1582
+ )[0]
1583
+ else:
1584
+ image_rec = self.vae.decode(x0 / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
1585
+
1586
+ image_rec = self.image_processor.postprocess(image_rec, output_type="pil")
1587
+
1588
+ # 5. find zs and xts
1589
+ variance_noise_shape = (num_inversion_steps, *x0.shape)
1590
+
1591
+ # intermediate latents
1592
+ t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
1593
+ xts = torch.zeros(size=variance_noise_shape, device=self.device, dtype=negative_prompt_embeds.dtype)
1594
+
1595
+ for t in reversed(timesteps):
1596
+ idx = num_inversion_steps - t_to_idx[int(t)] - 1
1597
+ noise = randn_tensor(shape=x0.shape, generator=generator, device=self.device, dtype=x0.dtype)
1598
+ xts[idx] = self.scheduler.add_noise(x0, noise, t.unsqueeze(0))
1599
+ xts = torch.cat([x0.unsqueeze(0), xts], dim=0)
1600
+
1601
+ # noise maps
1602
+ zs = torch.zeros(size=variance_noise_shape, device=self.device, dtype=negative_prompt_embeds.dtype)
1603
+
1604
+ self.scheduler.set_timesteps(len(self.scheduler.timesteps))
1605
+
1606
+ for t in self.progress_bar(timesteps):
1607
+ idx = num_inversion_steps - t_to_idx[int(t)] - 1
1608
+ # 1. predict noise residual
1609
+ xt = xts[idx + 1]
1610
+
1611
+ latent_model_input = torch.cat([xt] * 2) if do_classifier_free_guidance else xt
1612
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1613
+
1614
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1615
+
1616
+ noise_pred = self.unet(
1617
+ latent_model_input,
1618
+ t,
1619
+ encoder_hidden_states=negative_prompt_embeds,
1620
+ cross_attention_kwargs=cross_attention_kwargs,
1621
+ added_cond_kwargs=added_cond_kwargs,
1622
+ return_dict=False,
1623
+ )[0]
1624
+
1625
+ # 2. perform guidance
1626
+ if do_classifier_free_guidance:
1627
+ noise_pred_out = noise_pred.chunk(2)
1628
+ noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1]
1629
+ noise_pred = noise_pred_uncond + source_guidance_scale * (noise_pred_text - noise_pred_uncond)
1630
+
1631
+ xtm1 = xts[idx]
1632
+ z, xtm1_corrected = compute_noise(self.scheduler, xtm1, xt, t, noise_pred, self.eta)
1633
+ zs[idx] = z
1634
+
1635
+ # correction to avoid error accumulation
1636
+ xts[idx] = xtm1_corrected
1637
+
1638
+ self.init_latents = xts[-1]
1639
+ zs = zs.flip(0)
1640
+
1641
+ if num_zero_noise_steps > 0:
1642
+ zs[-num_zero_noise_steps:] = torch.zeros_like(zs[-num_zero_noise_steps:])
1643
+ self.zs = zs
1644
+ return LEditsPPInversionPipelineOutput(images=resized, vae_reconstruction_images=image_rec)
1645
+
1646
+
1647
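A short arithmetic sketch (editorial) of the schedule bookkeeping in `invert` above: the scheduler is asked for `int(num_inversion_steps * (1 + skip))` timesteps and only the last `num_inversion_steps` (the lower-noise end) are kept, so the first `skip` portion of the trajectory is never inverted.

```python
# Illustrative numbers matching the defaults of `invert` (num_inversion_steps=50, skip=0.15).
num_inversion_steps, skip = 50, 0.15
total_steps = int(num_inversion_steps * (1 + skip))  # 57 timesteps requested from the scheduler
kept = num_inversion_steps                           # the last 50 (lowest-noise) timesteps are used
discarded = total_steps - kept                       # 7 high-noise timesteps are skipped entirely
print(total_steps, kept, discarded)
```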
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.rescale_noise_cfg
1648
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
1649
+ """
1650
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
1651
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
1652
+ """
1653
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
1654
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
1655
+ # rescale the results from guidance (fixes overexposure)
1656
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
1657
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
1658
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
1659
+ return noise_cfg
1660
+
1661
+
1662
+ # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.compute_noise_ddim
1663
+ def compute_noise_ddim(scheduler, prev_latents, latents, timestep, noise_pred, eta):
1664
+ # 1. get previous step value (=t-1)
1665
+ prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
1666
+
1667
+ # 2. compute alphas, betas
1668
+ alpha_prod_t = scheduler.alphas_cumprod[timestep]
1669
+ alpha_prod_t_prev = (
1670
+ scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
1671
+ )
1672
+
1673
+ beta_prod_t = 1 - alpha_prod_t
1674
+
1675
+ # 3. compute predicted original sample from predicted noise also called
1676
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
1677
+ pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
1678
+
1679
+ # 4. Clip "predicted x_0"
1680
+ if scheduler.config.clip_sample:
1681
+ pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
1682
+
1683
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
1684
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
1685
+ variance = scheduler._get_variance(timestep, prev_timestep)
1686
+ std_dev_t = eta * variance ** (0.5)
1687
+
1688
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
1689
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred
1690
+
1691
+ # modified so that the updated xtm1 is returned as well (to avoid error accumulation)
1692
+ mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
1693
+ if variance > 0.0:
1694
+ noise = (prev_latents - mu_xt) / (variance ** (0.5) * eta)
1695
+ else:
1696
+ noise = torch.tensor([0.0]).to(latents.device)
1697
+
1698
+ return noise, mu_xt + (eta * variance**0.5) * noise
1699
+
1700
+
1701
+ # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.compute_noise_sde_dpm_pp_2nd
1702
+ def compute_noise_sde_dpm_pp_2nd(scheduler, prev_latents, latents, timestep, noise_pred, eta):
1703
+ def first_order_update(model_output, sample): # timestep, prev_timestep, sample):
1704
+ sigma_t, sigma_s = scheduler.sigmas[scheduler.step_index + 1], scheduler.sigmas[scheduler.step_index]
1705
+ alpha_t, sigma_t = scheduler._sigma_to_alpha_sigma_t(sigma_t)
1706
+ alpha_s, sigma_s = scheduler._sigma_to_alpha_sigma_t(sigma_s)
1707
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
1708
+ lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
1709
+
1710
+ h = lambda_t - lambda_s
1711
+
1712
+ mu_xt = (sigma_t / sigma_s * torch.exp(-h)) * sample + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
1713
+
1714
+ mu_xt = scheduler.dpm_solver_first_order_update(
1715
+ model_output=model_output, sample=sample, noise=torch.zeros_like(sample)
1716
+ )
1717
+
1718
+ sigma = sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h))
1719
+ if sigma > 0.0:
1720
+ noise = (prev_latents - mu_xt) / sigma
1721
+ else:
1722
+ noise = torch.tensor([0.0]).to(sample.device)
1723
+
1724
+ prev_sample = mu_xt + sigma * noise
1725
+ return noise, prev_sample
1726
+
1727
+ def second_order_update(model_output_list, sample): # timestep_list, prev_timestep, sample):
1728
+ sigma_t, sigma_s0, sigma_s1 = (
1729
+ scheduler.sigmas[scheduler.step_index + 1],
1730
+ scheduler.sigmas[scheduler.step_index],
1731
+ scheduler.sigmas[scheduler.step_index - 1],
1732
+ )
1733
+
1734
+ alpha_t, sigma_t = scheduler._sigma_to_alpha_sigma_t(sigma_t)
1735
+ alpha_s0, sigma_s0 = scheduler._sigma_to_alpha_sigma_t(sigma_s0)
1736
+ alpha_s1, sigma_s1 = scheduler._sigma_to_alpha_sigma_t(sigma_s1)
1737
+
1738
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
1739
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
1740
+ lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
1741
+
1742
+ m0, m1 = model_output_list[-1], model_output_list[-2]
1743
+
1744
+ h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
1745
+ r0 = h_0 / h
1746
+ D0, D1 = m0, (1.0 / r0) * (m0 - m1)
1747
+
1748
+ mu_xt = (
1749
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
1750
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
1751
+ + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
1752
+ )
1753
+
1754
+ sigma = sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h))
1755
+ if sigma > 0.0:
1756
+ noise = (prev_latents - mu_xt) / sigma
1757
+ else:
1758
+ noise = torch.tensor([0.0]).to(sample.device)
1759
+
1760
+ prev_sample = mu_xt + sigma * noise
1761
+
1762
+ return noise, prev_sample
1763
+
1764
+ if scheduler.step_index is None:
1765
+ scheduler._init_step_index(timestep)
1766
+
1767
+ model_output = scheduler.convert_model_output(model_output=noise_pred, sample=latents)
1768
+ for i in range(scheduler.config.solver_order - 1):
1769
+ scheduler.model_outputs[i] = scheduler.model_outputs[i + 1]
1770
+ scheduler.model_outputs[-1] = model_output
1771
+
1772
+ if scheduler.lower_order_nums < 1:
1773
+ noise, prev_sample = first_order_update(model_output, latents)
1774
+ else:
1775
+ noise, prev_sample = second_order_update(scheduler.model_outputs, latents)
1776
+
1777
+ if scheduler.lower_order_nums < scheduler.config.solver_order:
1778
+ scheduler.lower_order_nums += 1
1779
+
1780
+ # upon completion increase step index by one
1781
+ scheduler._step_index += 1
1782
+
1783
+ return noise, prev_sample
1784
+
1785
+
1786
+ # Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.compute_noise
1787
+ def compute_noise(scheduler, *args):
1788
+ if isinstance(scheduler, DDIMScheduler):
1789
+ return compute_noise_ddim(scheduler, *args)
1790
+ elif (
1791
+ isinstance(scheduler, DPMSolverMultistepScheduler)
1792
+ and scheduler.config.algorithm_type == "sde-dpmsolver++"
1793
+ and scheduler.config.solver_order == 2
1794
+ ):
1795
+ return compute_noise_sde_dpm_pp_2nd(scheduler, *args)
1796
+ else:
1797
+ raise NotImplementedError
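A hedged configuration sketch: `compute_noise` only supports `DDIMScheduler`, or `DPMSolverMultistepScheduler` with `algorithm_type="sde-dpmsolver++"` and `solver_order=2`; any other scheduler/config combination hits the `NotImplementedError` branch above. The checkpoint id below is a placeholder.

```python
# Sketch: set up a scheduler configuration that the compute_noise dispatcher accepts.
from diffusers import DPMSolverMultistepScheduler, LEditsPPPipelineStableDiffusionXL

pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0"  # placeholder checkpoint id
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config, algorithm_type="sde-dpmsolver++", solver_order=2
)
```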