diffusers 0.26.3__py3-none-any.whl → 0.27.0__py3-none-any.whl

Files changed (299)
  1. diffusers/__init__.py +20 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +7 -3
  7. diffusers/dependency_versions_check.py +1 -1
  8. diffusers/dependency_versions_table.py +2 -2
  9. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  10. diffusers/image_processor.py +110 -4
  11. diffusers/loaders/autoencoder.py +7 -8
  12. diffusers/loaders/controlnet.py +17 -8
  13. diffusers/loaders/ip_adapter.py +86 -23
  14. diffusers/loaders/lora.py +105 -310
  15. diffusers/loaders/lora_conversion_utils.py +1 -1
  16. diffusers/loaders/peft.py +1 -1
  17. diffusers/loaders/single_file.py +51 -12
  18. diffusers/loaders/single_file_utils.py +274 -49
  19. diffusers/loaders/textual_inversion.py +23 -4
  20. diffusers/loaders/unet.py +195 -41
  21. diffusers/loaders/utils.py +1 -1
  22. diffusers/models/__init__.py +3 -1
  23. diffusers/models/activations.py +9 -9
  24. diffusers/models/attention.py +26 -36
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +171 -114
  27. diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
  28. diffusers/models/autoencoders/autoencoder_kl.py +3 -1
  29. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  30. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  31. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  32. diffusers/models/autoencoders/vae.py +1 -1
  33. diffusers/models/controlnet.py +1 -1
  34. diffusers/models/controlnet_flax.py +1 -1
  35. diffusers/models/downsampling.py +8 -12
  36. diffusers/models/dual_transformer_2d.py +1 -1
  37. diffusers/models/embeddings.py +3 -4
  38. diffusers/models/embeddings_flax.py +1 -1
  39. diffusers/models/lora.py +33 -10
  40. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  41. diffusers/models/modeling_flax_utils.py +1 -1
  42. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  43. diffusers/models/modeling_utils.py +4 -6
  44. diffusers/models/normalization.py +1 -1
  45. diffusers/models/resnet.py +31 -58
  46. diffusers/models/resnet_flax.py +1 -1
  47. diffusers/models/t5_film_transformer.py +1 -1
  48. diffusers/models/transformer_2d.py +1 -1
  49. diffusers/models/transformer_temporal.py +1 -1
  50. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  51. diffusers/models/transformers/t5_film_transformer.py +1 -1
  52. diffusers/models/transformers/transformer_2d.py +29 -31
  53. diffusers/models/transformers/transformer_temporal.py +1 -1
  54. diffusers/models/unet_1d.py +1 -1
  55. diffusers/models/unet_1d_blocks.py +1 -1
  56. diffusers/models/unet_2d.py +1 -1
  57. diffusers/models/unet_2d_blocks.py +1 -1
  58. diffusers/models/unet_2d_condition.py +1 -1
  59. diffusers/models/unets/__init__.py +1 -0
  60. diffusers/models/unets/unet_1d.py +1 -1
  61. diffusers/models/unets/unet_1d_blocks.py +1 -1
  62. diffusers/models/unets/unet_2d.py +4 -4
  63. diffusers/models/unets/unet_2d_blocks.py +238 -98
  64. diffusers/models/unets/unet_2d_blocks_flax.py +1 -1
  65. diffusers/models/unets/unet_2d_condition.py +420 -323
  66. diffusers/models/unets/unet_2d_condition_flax.py +21 -12
  67. diffusers/models/unets/unet_3d_blocks.py +50 -40
  68. diffusers/models/unets/unet_3d_condition.py +47 -8
  69. diffusers/models/unets/unet_i2vgen_xl.py +75 -30
  70. diffusers/models/unets/unet_kandinsky3.py +1 -1
  71. diffusers/models/unets/unet_motion_model.py +48 -8
  72. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  73. diffusers/models/unets/unet_stable_cascade.py +610 -0
  74. diffusers/models/unets/uvit_2d.py +1 -1
  75. diffusers/models/upsampling.py +10 -16
  76. diffusers/models/vae_flax.py +1 -1
  77. diffusers/models/vq_model.py +1 -1
  78. diffusers/optimization.py +1 -1
  79. diffusers/pipelines/__init__.py +26 -0
  80. diffusers/pipelines/amused/pipeline_amused.py +1 -1
  81. diffusers/pipelines/amused/pipeline_amused_img2img.py +1 -1
  82. diffusers/pipelines/amused/pipeline_amused_inpaint.py +1 -1
  83. diffusers/pipelines/animatediff/pipeline_animatediff.py +162 -417
  84. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +165 -137
  85. diffusers/pipelines/animatediff/pipeline_output.py +7 -6
  86. diffusers/pipelines/audioldm/pipeline_audioldm.py +3 -19
  87. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  88. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +3 -3
  89. diffusers/pipelines/auto_pipeline.py +7 -16
  90. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  93. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  94. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  95. diffusers/pipelines/controlnet/pipeline_controlnet.py +90 -90
  96. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  97. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +98 -90
  98. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +92 -90
  99. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +145 -70
  100. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +126 -89
  101. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +108 -96
  102. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  103. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -1
  104. diffusers/pipelines/ddim/pipeline_ddim.py +1 -1
  105. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  106. diffusers/pipelines/deepfloyd_if/pipeline_if.py +4 -4
  107. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +4 -4
  108. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +5 -5
  109. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +4 -4
  110. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +5 -5
  111. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +5 -5
  112. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +10 -120
  113. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -91
  114. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  115. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +1 -1
  116. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  117. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +1 -1
  118. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  119. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  120. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  121. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  122. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  123. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  124. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +5 -4
  125. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +5 -4
  126. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +7 -22
  127. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -39
  128. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +5 -5
  129. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  130. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -22
  131. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  132. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  133. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -2
  134. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  135. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  136. diffusers/pipelines/free_init_utils.py +184 -0
  137. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +22 -104
  138. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +1 -1
  139. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  140. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +1 -1
  141. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +2 -2
  142. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -1
  143. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -1
  145. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +1 -1
  146. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
  147. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  148. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +2 -2
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +104 -93
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +112 -74
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/ledits_pp/__init__.py +55 -0
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +1505 -0
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +1797 -0
  155. diffusers/pipelines/ledits_pp/pipeline_output.py +43 -0
  156. diffusers/pipelines/musicldm/pipeline_musicldm.py +3 -19
  157. diffusers/pipelines/onnx_utils.py +1 -1
  158. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  159. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +3 -3
  160. diffusers/pipelines/pia/pipeline_pia.py +168 -327
  161. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  162. diffusers/pipelines/pipeline_loading_utils.py +508 -0
  163. diffusers/pipelines/pipeline_utils.py +188 -534
  164. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +56 -10
  165. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +3 -3
  166. diffusers/pipelines/shap_e/camera.py +1 -1
  167. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  168. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  169. diffusers/pipelines/shap_e/renderer.py +1 -1
  170. diffusers/pipelines/stable_cascade/__init__.py +50 -0
  171. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +482 -0
  172. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +311 -0
  173. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +638 -0
  174. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  175. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +4 -1
  176. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  177. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +2 -2
  178. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  179. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +1 -1
  180. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +1 -1
  181. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  182. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  183. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +90 -146
  184. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
  185. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +4 -32
  186. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -119
  187. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -119
  188. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +13 -59
  189. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -31
  190. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -33
  191. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -21
  192. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -21
  193. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  194. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  195. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  196. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +5 -21
  197. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +9 -38
  198. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -34
  199. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +6 -35
  200. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +7 -6
  201. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +4 -124
  202. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +282 -80
  203. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +94 -46
  204. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +3 -3
  205. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  206. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +6 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  208. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +96 -148
  209. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -154
  210. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +98 -153
  211. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +25 -87
  212. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +89 -80
  213. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +5 -49
  214. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +80 -88
  215. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +8 -6
  216. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +15 -86
  217. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +20 -93
  218. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +5 -5
  219. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +3 -19
  220. diffusers/pipelines/unclip/pipeline_unclip.py +1 -1
  221. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -1
  222. diffusers/pipelines/unclip/text_proj.py +1 -1
  223. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +35 -35
  224. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  225. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +4 -21
  226. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +2 -2
  227. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -5
  228. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  229. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +1 -1
  230. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +2 -2
  231. diffusers/schedulers/__init__.py +7 -1
  232. diffusers/schedulers/deprecated/scheduling_karras_ve.py +1 -1
  233. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  234. diffusers/schedulers/scheduling_consistency_models.py +42 -19
  235. diffusers/schedulers/scheduling_ddim.py +2 -4
  236. diffusers/schedulers/scheduling_ddim_flax.py +13 -5
  237. diffusers/schedulers/scheduling_ddim_inverse.py +2 -4
  238. diffusers/schedulers/scheduling_ddim_parallel.py +2 -4
  239. diffusers/schedulers/scheduling_ddpm.py +2 -4
  240. diffusers/schedulers/scheduling_ddpm_flax.py +1 -1
  241. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -4
  242. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +1 -1
  243. diffusers/schedulers/scheduling_deis_multistep.py +46 -19
  244. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -21
  245. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +1 -1
  246. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +9 -7
  247. diffusers/schedulers/scheduling_dpmsolver_sde.py +35 -35
  248. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +49 -18
  249. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +683 -0
  250. diffusers/schedulers/scheduling_edm_euler.py +381 -0
  251. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +43 -15
  252. diffusers/schedulers/scheduling_euler_discrete.py +42 -17
  253. diffusers/schedulers/scheduling_euler_discrete_flax.py +1 -1
  254. diffusers/schedulers/scheduling_heun_discrete.py +35 -35
  255. diffusers/schedulers/scheduling_ipndm.py +37 -11
  256. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +44 -44
  257. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +44 -44
  258. diffusers/schedulers/scheduling_karras_ve_flax.py +1 -1
  259. diffusers/schedulers/scheduling_lcm.py +38 -14
  260. diffusers/schedulers/scheduling_lms_discrete.py +43 -15
  261. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  262. diffusers/schedulers/scheduling_pndm.py +2 -4
  263. diffusers/schedulers/scheduling_pndm_flax.py +2 -4
  264. diffusers/schedulers/scheduling_repaint.py +1 -1
  265. diffusers/schedulers/scheduling_sasolver.py +41 -9
  266. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  267. diffusers/schedulers/scheduling_sde_ve_flax.py +1 -1
  268. diffusers/schedulers/scheduling_tcd.py +686 -0
  269. diffusers/schedulers/scheduling_unclip.py +1 -1
  270. diffusers/schedulers/scheduling_unipc_multistep.py +46 -19
  271. diffusers/schedulers/scheduling_utils.py +2 -1
  272. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  273. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  274. diffusers/training_utils.py +9 -2
  275. diffusers/utils/__init__.py +2 -1
  276. diffusers/utils/accelerate_utils.py +1 -1
  277. diffusers/utils/constants.py +1 -1
  278. diffusers/utils/doc_utils.py +1 -1
  279. diffusers/utils/dummy_pt_objects.py +60 -0
  280. diffusers/utils/dummy_torch_and_transformers_objects.py +75 -0
  281. diffusers/utils/dynamic_modules_utils.py +1 -1
  282. diffusers/utils/export_utils.py +3 -3
  283. diffusers/utils/hub_utils.py +60 -16
  284. diffusers/utils/import_utils.py +15 -1
  285. diffusers/utils/loading_utils.py +2 -0
  286. diffusers/utils/logging.py +1 -1
  287. diffusers/utils/model_card_template.md +24 -0
  288. diffusers/utils/outputs.py +14 -7
  289. diffusers/utils/peft_utils.py +1 -1
  290. diffusers/utils/state_dict_utils.py +1 -1
  291. diffusers/utils/testing_utils.py +2 -0
  292. diffusers/utils/torch_utils.py +1 -1
  293. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/METADATA +46 -46
  294. diffusers-0.27.0.dist-info/RECORD +399 -0
  295. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/WHEEL +1 -1
  296. diffusers-0.26.3.dist-info/RECORD +0 -384
  297. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/LICENSE +0 -0
  298. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/entry_points.txt +0 -0
  299. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/top_level.txt +0 -0
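The new modules listed above (the LEDITS++ pipelines, the Stable Cascade pipelines, and the TCD/EDM schedulers) appear to be re-exported from the package root, as in previous releases. A minimal sketch to confirm the upgraded wheel is active (assumes a standard install of the 0.27.0 wheel; class names are taken from the file list above):

```py
import diffusers
from diffusers import (
    LEditsPPPipelineStableDiffusion,  # entry 153: new LEDITS++ pipeline
    StableCascadePriorPipeline,       # entry 173: new Stable Cascade prior pipeline
    TCDScheduler,                     # entry 268: new TCD scheduler
)

print(diffusers.__version__)  # expected: 0.27.0
```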
diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
@@ -0,0 +1,1505 @@
1
+ import inspect
2
+ import math
3
+ from itertools import repeat
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from packaging import version
9
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
10
+
11
+ from ...configuration_utils import FrozenDict
12
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
13
+ from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
14
+ from ...models import AutoencoderKL, UNet2DConditionModel
15
+ from ...models.attention_processor import Attention, AttnProcessor
16
+ from ...models.lora import adjust_lora_scale_text_encoder
17
+ from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
18
+ from ...schedulers import DDIMScheduler, DPMSolverMultistepScheduler
19
+ from ...utils import (
20
+ USE_PEFT_BACKEND,
21
+ deprecate,
22
+ logging,
23
+ replace_example_docstring,
24
+ scale_lora_layers,
25
+ unscale_lora_layers,
26
+ )
27
+ from ...utils.torch_utils import randn_tensor
28
+ from ..pipeline_utils import DiffusionPipeline
29
+ from .pipeline_output import LEditsPPDiffusionPipelineOutput, LEditsPPInversionPipelineOutput
30
+
31
+
32
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
33
+
34
+ EXAMPLE_DOC_STRING = """
35
+ Examples:
36
+ ```py
37
+ >>> import PIL
38
+ >>> import requests
39
+ >>> import torch
40
+ >>> from io import BytesIO
41
+
42
+ >>> from diffusers import LEditsPPPipelineStableDiffusion
43
+
44
+ >>> pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
45
+ ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
46
+ ... )
47
+ >>> pipe = pipe.to("cuda")
48
+
49
+ >>> def download_image(url):
50
+ ... response = requests.get(url)
51
+ ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
52
+
53
+ >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/cherry_blossom.png"
54
+ >>> image = download_image(img_url)
55
+
56
+ >>> _ = pipe.invert(
57
+ ... image = image,
58
+ ... num_inversion_steps=50,
59
+ ... skip=0.1
60
+ ... )
61
+
62
+ >>> edited_image = pipe(
63
+ ... editing_prompt=["cherry blossom"],
64
+ ... edit_guidance_scale=10.0,
65
+ ... edit_threshold=0.75,
66
+ ).images[0]
67
+ ```
68
+ """
69
+
70
+
71
+ # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionAttendAndExcitePipeline.AttentionStore
72
+ class LeditsAttentionStore:
73
+ @staticmethod
74
+ def get_empty_store():
75
+ return {"down_cross": [], "mid_cross": [], "up_cross": [], "down_self": [], "mid_self": [], "up_self": []}
76
+
77
+ def __call__(self, attn, is_cross: bool, place_in_unet: str, editing_prompts, PnP=False):
78
+ # attn.shape = (batch_size * head_size, seq_len_query, seq_len_key)
79
+ if attn.shape[1] <= self.max_size:
80
+ bs = 1 + int(PnP) + editing_prompts
81
+ skip = 2 if PnP else 1 # skip PnP & unconditional
82
+ attn = torch.stack(attn.split(self.batch_size)).permute(1, 0, 2, 3)
83
+ source_batch_size = int(attn.shape[1] // bs)
84
+ self.forward(attn[:, skip * source_batch_size :], is_cross, place_in_unet)
85
+
86
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
87
+ key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
88
+
89
+ self.step_store[key].append(attn)
90
+
91
+ def between_steps(self, store_step=True):
92
+ if store_step:
93
+ if self.average:
94
+ if len(self.attention_store) == 0:
95
+ self.attention_store = self.step_store
96
+ else:
97
+ for key in self.attention_store:
98
+ for i in range(len(self.attention_store[key])):
99
+ self.attention_store[key][i] += self.step_store[key][i]
100
+ else:
101
+ if len(self.attention_store) == 0:
102
+ self.attention_store = [self.step_store]
103
+ else:
104
+ self.attention_store.append(self.step_store)
105
+
106
+ self.cur_step += 1
107
+ self.step_store = self.get_empty_store()
108
+
109
+ def get_attention(self, step: int):
110
+ if self.average:
111
+ attention = {
112
+ key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store
113
+ }
114
+ else:
115
+ assert step is not None
116
+ attention = self.attention_store[step]
117
+ return attention
118
+
119
+ def aggregate_attention(
120
+ self, attention_maps, prompts, res: Union[int, Tuple[int]], from_where: List[str], is_cross: bool, select: int
121
+ ):
122
+ out = [[] for x in range(self.batch_size)]
123
+ if isinstance(res, int):
124
+ num_pixels = res**2
125
+ resolution = (res, res)
126
+ else:
127
+ num_pixels = res[0] * res[1]
128
+ resolution = res[:2]
129
+
130
+ for location in from_where:
131
+ for bs_item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
132
+ for batch, item in enumerate(bs_item):
133
+ if item.shape[1] == num_pixels:
134
+ cross_maps = item.reshape(len(prompts), -1, *resolution, item.shape[-1])[select]
135
+ out[batch].append(cross_maps)
136
+
137
+ out = torch.stack([torch.cat(x, dim=0) for x in out])
138
+ # average over heads
139
+ out = out.sum(1) / out.shape[1]
140
+ return out
141
+
142
+ def __init__(self, average: bool, batch_size=1, max_resolution=16, max_size: int = None):
143
+ self.step_store = self.get_empty_store()
144
+ self.attention_store = []
145
+ self.cur_step = 0
146
+ self.average = average
147
+ self.batch_size = batch_size
148
+ if max_size is None:
149
+ self.max_size = max_resolution**2
150
+ elif max_size is not None and max_resolution is None:
151
+ self.max_size = max_size
152
+ else:
153
+ raise ValueError("Only allowed to set one of max_resolution or max_size")
154
+
155
+
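A small usage sketch for the attention store above (hypothetical values; assumes `LeditsAttentionStore` is imported from this module). Note that the constructor accepts only one of `max_resolution` or `max_size`, so `max_resolution` must be set to `None` when passing `max_size`:

```py
# Default: track attention maps up to 16x16 resolution, averaged over steps.
store = LeditsAttentionStore(average=True, batch_size=1, max_resolution=16)

# To cap by total map size instead, max_resolution must be disabled explicitly.
store = LeditsAttentionStore(average=False, batch_size=1, max_resolution=None, max_size=32**2)
```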
156
+ # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionAttendAndExcitePipeline.GaussianSmoothing
157
+ class LeditsGaussianSmoothing:
158
+ def __init__(self, device):
159
+ kernel_size = [3, 3]
160
+ sigma = [0.5, 0.5]
161
+
162
+ # The gaussian kernel is the product of the gaussian function of each dimension.
163
+ kernel = 1
164
+ meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
165
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
166
+ mean = (size - 1) / 2
167
+ kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
168
+
169
+ # Make sure sum of values in gaussian kernel equals 1.
170
+ kernel = kernel / torch.sum(kernel)
171
+
172
+ # Reshape to depthwise convolutional weight
173
+ kernel = kernel.view(1, 1, *kernel.size())
174
+ kernel = kernel.repeat(1, *[1] * (kernel.dim() - 1))
175
+
176
+ self.weight = kernel.to(device)
177
+
178
+ def __call__(self, input):
179
+ """
180
+ Apply gaussian filter to input.
181
+ Arguments:
182
+ input (torch.Tensor): Input to apply gaussian filter on.
183
+ Returns:
184
+ filtered (torch.Tensor): Filtered output.
185
+ """
186
+ return F.conv2d(input, weight=self.weight.to(input.dtype))
187
+
188
+
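A sketch of how the smoothing above can be applied to a single-channel attention map (assumes `LeditsGaussianSmoothing` is imported from this module; shapes are illustrative). The 3x3 kernel shrinks the spatial dimensions by two, so reflect-padding by one pixel on each side keeps the output shape equal to the input:

```py
import torch
import torch.nn.functional as F

smoothing = LeditsGaussianSmoothing(device="cpu")
attn_map = torch.rand(1, 1, 16, 16)  # (batch, channel, height, width)
smoothed = smoothing(F.pad(attn_map, (1, 1, 1, 1), mode="reflect"))
assert smoothed.shape == attn_map.shape
```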
189
+ class LEDITSCrossAttnProcessor:
190
+ def __init__(self, attention_store, place_in_unet, pnp, editing_prompts):
191
+ self.attnstore = attention_store
192
+ self.place_in_unet = place_in_unet
193
+ self.editing_prompts = editing_prompts
194
+ self.pnp = pnp
195
+
196
+ def __call__(
197
+ self,
198
+ attn: Attention,
199
+ hidden_states,
200
+ encoder_hidden_states,
201
+ attention_mask=None,
202
+ temb=None,
203
+ ):
204
+ batch_size, sequence_length, _ = (
205
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
206
+ )
207
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
208
+
209
+ query = attn.to_q(hidden_states)
210
+
211
+ if encoder_hidden_states is None:
212
+ encoder_hidden_states = hidden_states
213
+ elif attn.norm_cross:
214
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
215
+
216
+ key = attn.to_k(encoder_hidden_states)
217
+ value = attn.to_v(encoder_hidden_states)
218
+
219
+ query = attn.head_to_batch_dim(query)
220
+ key = attn.head_to_batch_dim(key)
221
+ value = attn.head_to_batch_dim(value)
222
+
223
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
224
+ self.attnstore(
225
+ attention_probs,
226
+ is_cross=True,
227
+ place_in_unet=self.place_in_unet,
228
+ editing_prompts=self.editing_prompts,
229
+ PnP=self.pnp,
230
+ )
231
+
232
+ hidden_states = torch.bmm(attention_probs, value)
233
+ hidden_states = attn.batch_to_head_dim(hidden_states)
234
+
235
+ # linear proj
236
+ hidden_states = attn.to_out[0](hidden_states)
237
+ # dropout
238
+ hidden_states = attn.to_out[1](hidden_states)
239
+
240
+ hidden_states = hidden_states / attn.rescale_output_factor
241
+ return hidden_states
242
+
243
+
244
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
245
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
246
+ """
247
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
248
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
249
+ """
250
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
251
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
252
+ # rescale the results from guidance (fixes overexposure)
253
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
254
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
255
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
256
+ return noise_cfg
257
+
258
+
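A quick numeric sketch of the rescaling above with dummy tensors (assumes `rescale_noise_cfg` is imported from this module; values are arbitrary):

```py
import torch

noise_pred_text = torch.randn(2, 4, 64, 64)
noise_cfg = noise_pred_text + 0.3 * torch.randn_like(noise_pred_text)

# guidance_rescale=0.0 leaves the input untouched; 1.0 matches the per-sample std
# of the text-conditioned prediction exactly.
rescaled = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)
print(noise_cfg.std().item(), rescaled.std().item())
```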
259
+ class LEditsPPPipelineStableDiffusion(
260
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
261
+ ):
262
+ """
263
+ Pipeline for textual image editing using LEDits++ with Stable Diffusion.
264
+
265
+ This model inherits from [`DiffusionPipeline`] and builds on the [`StableDiffusionPipeline`]. Check the superclass
266
+ documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular
267
+ device, etc.).
268
+
269
+ Args:
270
+ vae ([`AutoencoderKL`]):
271
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
272
+ text_encoder ([`~transformers.CLIPTextModel`]):
273
+ Frozen text-encoder. Stable Diffusion uses the text portion of
274
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
275
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
276
+ tokenizer ([`~transformers.CLIPTokenizer`]):
277
+ Tokenizer of class
278
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
279
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
280
+ scheduler ([`DPMSolverMultistepScheduler`] or [`DDIMScheduler`]):
281
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
282
+ [`DPMSolverMultistepScheduler`] or [`DDIMScheduler`]. If any other scheduler is passed it will automatically
283
+ be set to [`DPMSolverMultistepScheduler`].
284
+ safety_checker ([`StableDiffusionSafetyChecker`]):
285
+ Classification module that estimates whether generated images could be considered offensive or harmful.
286
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
287
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
288
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
289
+ """
290
+
291
+ model_cpu_offload_seq = "text_encoder->unet->vae"
292
+ _exclude_from_cpu_offload = ["safety_checker"]
293
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
294
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
295
+
296
+ def __init__(
297
+ self,
298
+ vae: AutoencoderKL,
299
+ text_encoder: CLIPTextModel,
300
+ tokenizer: CLIPTokenizer,
301
+ unet: UNet2DConditionModel,
302
+ scheduler: Union[DDIMScheduler, DPMSolverMultistepScheduler],
303
+ safety_checker: StableDiffusionSafetyChecker,
304
+ feature_extractor: CLIPImageProcessor,
305
+ requires_safety_checker: bool = True,
306
+ ):
307
+ super().__init__()
308
+
309
+ if not isinstance(scheduler, DDIMScheduler) and not isinstance(scheduler, DPMSolverMultistepScheduler):
310
+ scheduler = DPMSolverMultistepScheduler.from_config(
311
+ scheduler.config, algorithm_type="sde-dpmsolver++", solver_order=2
312
+ )
313
+ logger.warning(
314
+ "This pipeline only supports DDIMScheduler and DPMSolverMultistepScheduler. "
315
+ "The scheduler has been changed to DPMSolverMultistepScheduler."
316
+ )
317
+
318
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
319
+ deprecation_message = (
320
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
321
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
322
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
323
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
324
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
325
+ " file"
326
+ )
327
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
328
+ new_config = dict(scheduler.config)
329
+ new_config["steps_offset"] = 1
330
+ scheduler._internal_dict = FrozenDict(new_config)
331
+
332
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
333
+ deprecation_message = (
334
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
335
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
336
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
337
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
338
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
339
+ )
340
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
341
+ new_config = dict(scheduler.config)
342
+ new_config["clip_sample"] = False
343
+ scheduler._internal_dict = FrozenDict(new_config)
344
+
345
+ if safety_checker is None and requires_safety_checker:
346
+ logger.warning(
347
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
348
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
349
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
350
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
351
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
352
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
353
+ )
354
+
355
+ if safety_checker is not None and feature_extractor is None:
356
+ raise ValueError(
357
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
358
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
359
+ )
360
+
361
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
362
+ version.parse(unet.config._diffusers_version).base_version
363
+ ) < version.parse("0.9.0.dev0")
364
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
365
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
366
+ deprecation_message = (
367
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
368
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
369
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
370
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
371
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
372
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
373
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
374
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
375
+ " the `unet/config.json` file"
376
+ )
377
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
378
+ new_config = dict(unet.config)
379
+ new_config["sample_size"] = 64
380
+ unet._internal_dict = FrozenDict(new_config)
381
+
382
+ self.register_modules(
383
+ vae=vae,
384
+ text_encoder=text_encoder,
385
+ tokenizer=tokenizer,
386
+ unet=unet,
387
+ scheduler=scheduler,
388
+ safety_checker=safety_checker,
389
+ feature_extractor=feature_extractor,
390
+ )
391
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
392
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
393
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
394
+
395
+ self.inversion_steps = None
396
+
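Because the constructor above silently swaps unsupported schedulers, a caller who wants deterministic behaviour can pass a supported scheduler explicitly. A minimal sketch (the checkpoint name is taken from the example docstring earlier in this file):

```py
import torch
from diffusers import DPMSolverMultistepScheduler, LEditsPPPipelineStableDiffusion

pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
# Same configuration the constructor would fall back to for unsupported schedulers.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config, algorithm_type="sde-dpmsolver++", solver_order=2
)
pipe = pipe.to("cuda")
```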
397
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
398
+ def run_safety_checker(self, image, device, dtype):
399
+ if self.safety_checker is None:
400
+ has_nsfw_concept = None
401
+ else:
402
+ if torch.is_tensor(image):
403
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
404
+ else:
405
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
406
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
407
+ image, has_nsfw_concept = self.safety_checker(
408
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
409
+ )
410
+ return image, has_nsfw_concept
411
+
412
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
413
+ def decode_latents(self, latents):
414
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
415
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
416
+
417
+ latents = 1 / self.vae.config.scaling_factor * latents
418
+ image = self.vae.decode(latents, return_dict=False)[0]
419
+ image = (image / 2 + 0.5).clamp(0, 1)
420
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
421
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
422
+ return image
423
+
424
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
425
+ def prepare_extra_step_kwargs(self, eta, generator=None):
426
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
427
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
428
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
429
+ # and should be between [0, 1]
430
+
431
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
432
+ extra_step_kwargs = {}
433
+ if accepts_eta:
434
+ extra_step_kwargs["eta"] = eta
435
+
436
+ # check if the scheduler accepts generator
437
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
438
+ if accepts_generator:
439
+ extra_step_kwargs["generator"] = generator
440
+ return extra_step_kwargs
441
+
442
+ # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
443
+ def check_inputs(
444
+ self,
445
+ negative_prompt=None,
446
+ editing_prompt_embeddings=None,
447
+ negative_prompt_embeds=None,
448
+ callback_on_step_end_tensor_inputs=None,
449
+ ):
450
+ if callback_on_step_end_tensor_inputs is not None and not all(
451
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
452
+ ):
453
+ raise ValueError(
454
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
455
+ )
456
+ if negative_prompt is not None and negative_prompt_embeds is not None:
457
+ raise ValueError(
458
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
459
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
460
+ )
461
+
462
+ if editing_prompt_embeddings is not None and negative_prompt_embeds is not None:
463
+ if editing_prompt_embeddings.shape != negative_prompt_embeds.shape:
464
+ raise ValueError(
465
+ "`editing_prompt_embeddings` and `negative_prompt_embeds` must have the same shape when passed directly, but"
466
+ f" got: `editing_prompt_embeddings` {editing_prompt_embeddings.shape} != `negative_prompt_embeds`"
467
+ f" {negative_prompt_embeds.shape}."
468
+ )
469
+
470
+ # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
471
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, latents):
472
+ # shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
473
+
474
+ # if latents.shape != shape:
475
+ # raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
476
+
477
+ latents = latents.to(device)
478
+
479
+ # scale the initial noise by the standard deviation required by the scheduler
480
+ latents = latents * self.scheduler.init_noise_sigma
481
+ return latents
482
+
483
+ def prepare_unet(self, attention_store, PnP: bool = False):
484
+ attn_procs = {}
485
+ for name in self.unet.attn_processors.keys():
486
+ if name.startswith("mid_block"):
487
+ place_in_unet = "mid"
488
+ elif name.startswith("up_blocks"):
489
+ place_in_unet = "up"
490
+ elif name.startswith("down_blocks"):
491
+ place_in_unet = "down"
492
+ else:
493
+ continue
494
+
495
+ if "attn2" in name and place_in_unet != "mid":
496
+ attn_procs[name] = LEDITSCrossAttnProcessor(
497
+ attention_store=attention_store,
498
+ place_in_unet=place_in_unet,
499
+ pnp=PnP,
500
+ editing_prompts=self.enabled_editing_prompts,
501
+ )
502
+ else:
503
+ attn_procs[name] = AttnProcessor()
504
+
505
+ self.unet.set_attn_processor(attn_procs)
506
+
507
+ def encode_prompt(
508
+ self,
509
+ device,
510
+ num_images_per_prompt,
511
+ enable_edit_guidance,
512
+ negative_prompt=None,
513
+ editing_prompt=None,
514
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
515
+ editing_prompt_embeds: Optional[torch.FloatTensor] = None,
516
+ lora_scale: Optional[float] = None,
517
+ clip_skip: Optional[int] = None,
518
+ ):
519
+ r"""
520
+ Encodes the prompt into text encoder hidden states.
521
+
522
+ Args:
523
+ device: (`torch.device`):
524
+ torch device
525
+ num_images_per_prompt (`int`):
526
+ number of images that should be generated per prompt
527
+ enable_edit_guidance (`bool`):
528
+ whether to perform any editing or reconstruct the input image instead
529
+ negative_prompt (`str` or `List[str]`, *optional*):
530
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
531
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
532
+ less than `1`).
533
+ editing_prompt (`str` or `List[str]`, *optional*):
534
+ Editing prompt(s) to be encoded. If not defined, one has to pass
535
+ `editing_prompt_embeds` instead.
536
+ editing_prompt_embeds (`torch.FloatTensor`, *optional*):
537
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
538
+ provided, text embeddings will be generated from `prompt` input argument.
539
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
540
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
541
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
542
+ argument.
543
+ lora_scale (`float`, *optional*):
544
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
545
+ clip_skip (`int`, *optional*):
546
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
547
+ the output of the pre-final layer will be used for computing the prompt embeddings.
548
+ """
549
+ # set lora scale so that monkey patched LoRA
550
+ # function of text encoder can correctly access it
551
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
552
+ self._lora_scale = lora_scale
553
+
554
+ # dynamically adjust the LoRA scale
555
+ if not USE_PEFT_BACKEND:
556
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
557
+ else:
558
+ scale_lora_layers(self.text_encoder, lora_scale)
559
+
560
+ batch_size = self.batch_size
561
+ num_edit_tokens = None
562
+
563
+ if negative_prompt_embeds is None:
564
+ uncond_tokens: List[str]
565
+ if negative_prompt is None:
566
+ uncond_tokens = [""] * batch_size
567
+ elif isinstance(negative_prompt, str):
568
+ uncond_tokens = [negative_prompt]
569
+ elif batch_size != len(negative_prompt):
570
+ raise ValueError(
571
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but exoected"
572
+ f"{batch_size} based on the input images. Please make sure that passed `negative_prompt` matches"
573
+ " the batch size of `prompt`."
574
+ )
575
+ else:
576
+ uncond_tokens = negative_prompt
577
+
578
+ # textual inversion: process multi-vector tokens if necessary
579
+ if isinstance(self, TextualInversionLoaderMixin):
580
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
581
+
582
+ uncond_input = self.tokenizer(
583
+ uncond_tokens,
584
+ padding="max_length",
585
+ max_length=self.tokenizer.model_max_length,
586
+ truncation=True,
587
+ return_tensors="pt",
588
+ )
589
+
590
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
591
+ attention_mask = uncond_input.attention_mask.to(device)
592
+ else:
593
+ attention_mask = None
594
+
595
+ negative_prompt_embeds = self.text_encoder(
596
+ uncond_input.input_ids.to(device),
597
+ attention_mask=attention_mask,
598
+ )
599
+ negative_prompt_embeds = negative_prompt_embeds[0]
600
+
601
+ if self.text_encoder is not None:
602
+ prompt_embeds_dtype = self.text_encoder.dtype
603
+ elif self.unet is not None:
604
+ prompt_embeds_dtype = self.unet.dtype
605
+ else:
606
+ prompt_embeds_dtype = negative_prompt_embeds.dtype
607
+
608
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
609
+
610
+ if enable_edit_guidance:
611
+ if editing_prompt_embeds is None:
612
+ # textual inversion: process multi-vector tokens if necessary
613
+ # if isinstance(self, TextualInversionLoaderMixin):
614
+ # prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
615
+ if isinstance(editing_prompt, str):
616
+ editing_prompt = [editing_prompt]
617
+
618
+ max_length = negative_prompt_embeds.shape[1]
619
+ text_inputs = self.tokenizer(
620
+ [x for item in editing_prompt for x in repeat(item, batch_size)],
621
+ padding="max_length",
622
+ max_length=max_length,
623
+ truncation=True,
624
+ return_tensors="pt",
625
+ return_length=True,
626
+ )
627
+
628
+ num_edit_tokens = text_inputs.length - 2 # not counting startoftext and endoftext
629
+ text_input_ids = text_inputs.input_ids
630
+ untruncated_ids = self.tokenizer(
631
+ [x for item in editing_prompt for x in repeat(item, batch_size)],
632
+ padding="longest",
633
+ return_tensors="pt",
634
+ ).input_ids
635
+
636
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
637
+ text_input_ids, untruncated_ids
638
+ ):
639
+ removed_text = self.tokenizer.batch_decode(
640
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
641
+ )
642
+ logger.warning(
643
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
644
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
645
+ )
646
+
647
+ if (
648
+ hasattr(self.text_encoder.config, "use_attention_mask")
649
+ and self.text_encoder.config.use_attention_mask
650
+ ):
651
+ attention_mask = text_inputs.attention_mask.to(device)
652
+ else:
653
+ attention_mask = None
654
+
655
+ if clip_skip is None:
656
+ editing_prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
657
+ editing_prompt_embeds = editing_prompt_embeds[0]
658
+ else:
659
+ editing_prompt_embeds = self.text_encoder(
660
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
661
+ )
662
+ # Access the `hidden_states` first, that contains a tuple of
663
+ # all the hidden states from the encoder layers. Then index into
664
+ # the tuple to access the hidden states from the desired layer.
665
+ editing_prompt_embeds = editing_prompt_embeds[-1][-(clip_skip + 1)]
666
+ # We also need to apply the final LayerNorm here to not mess with the
667
+ # representations. The `last_hidden_states` that we typically use for
668
+ # obtaining the final prompt representations passes through the LayerNorm
669
+ # layer.
670
+ editing_prompt_embeds = self.text_encoder.text_model.final_layer_norm(editing_prompt_embeds)
671
+
672
+ editing_prompt_embeds = editing_prompt_embeds.to(dtype=negative_prompt_embeds.dtype, device=device)
673
+
674
+ bs_embed_edit, seq_len, _ = editing_prompt_embeds.shape
675
+ editing_prompt_embeds = editing_prompt_embeds.to(dtype=negative_prompt_embeds.dtype, device=device)
676
+ editing_prompt_embeds = editing_prompt_embeds.repeat(1, num_images_per_prompt, 1)
677
+ editing_prompt_embeds = editing_prompt_embeds.view(bs_embed_edit * num_images_per_prompt, seq_len, -1)
678
+
679
+ # get unconditional embeddings for classifier free guidance
680
+
681
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
682
+ seq_len = negative_prompt_embeds.shape[1]
683
+
684
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
685
+
686
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
687
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
688
+
689
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
690
+ # Retrieve the original scale by scaling back the LoRA layers
691
+ unscale_lora_layers(self.text_encoder, lora_scale)
692
+
693
+ return editing_prompt_embeds, negative_prompt_embeds, num_edit_tokens
694
+
695
+ @property
696
+ def guidance_rescale(self):
697
+ return self._guidance_rescale
698
+
699
+ @property
700
+ def clip_skip(self):
701
+ return self._clip_skip
702
+
703
+ @property
704
+ def cross_attention_kwargs(self):
705
+ return self._cross_attention_kwargs
706
+
707
+ @torch.no_grad()
708
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
709
+ def __call__(
710
+ self,
711
+ negative_prompt: Optional[Union[str, List[str]]] = None,
712
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
713
+ output_type: Optional[str] = "pil",
714
+ return_dict: bool = True,
715
+ editing_prompt: Optional[Union[str, List[str]]] = None,
716
+ editing_prompt_embeds: Optional[torch.Tensor] = None,
717
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
718
+ reverse_editing_direction: Optional[Union[bool, List[bool]]] = False,
719
+ edit_guidance_scale: Optional[Union[float, List[float]]] = 5,
720
+ edit_warmup_steps: Optional[Union[int, List[int]]] = 0,
721
+ edit_cooldown_steps: Optional[Union[int, List[int]]] = None,
722
+ edit_threshold: Optional[Union[float, List[float]]] = 0.9,
723
+ user_mask: Optional[torch.FloatTensor] = None,
724
+ sem_guidance: Optional[List[torch.Tensor]] = None,
725
+ use_cross_attn_mask: bool = False,
726
+ use_intersect_mask: bool = True,
727
+ attn_store_steps: Optional[List[int]] = [],
728
+ store_averaged_over_steps: bool = True,
729
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
730
+ guidance_rescale: float = 0.0,
731
+ clip_skip: Optional[int] = None,
732
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
733
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
734
+ **kwargs,
735
+ ):
736
+ r"""
737
+ The call function to the pipeline for editing. The [`~pipelines.ledits_pp.LEditsPPPipelineStableDiffusion.invert`]
738
+ method has to be called beforehand. Edits will always be performed for the last inverted image(s).
739
+
740
+ Args:
741
+ negative_prompt (`str` or `List[str]`, *optional*):
742
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
743
+ if `guidance_scale` is less than `1`).
744
+ generator (`torch.Generator`, *optional*):
745
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
746
+ to make generation deterministic.
747
+ output_type (`str`, *optional*, defaults to `"pil"`):
748
+ The output format of the generated image. Choose between
749
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
750
+ return_dict (`bool`, *optional*, defaults to `True`):
751
+ Whether or not to return a [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] instead of a
752
+ plain tuple.
753
+ editing_prompt (`str` or `List[str]`, *optional*):
754
+ The prompt or prompts to guide the image generation. The image is reconstructed by setting
755
+ `editing_prompt = None`. Guidance direction of prompt should be specified via `reverse_editing_direction`.
756
+ editing_prompt_embeds (`torch.Tensor>`, *optional*):
757
+ Pre-computed embeddings to use for guiding the image generation. Guidance direction of embedding should be
758
+ specified via `reverse_editing_direction`.
759
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
760
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
761
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
762
+ reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`):
763
+ Whether the corresponding prompt in `editing_prompt` should be increased or decreased.
764
+ edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5):
765
+ Guidance scale for guiding the image generation. If provided as a list, values should correspond to `editing_prompt`.
766
+ `edit_guidance_scale` is defined as `s_e` of equation 12 of
767
+ [LEDITS++ Paper](https://arxiv.org/abs/2301.12247).
768
+ edit_warmup_steps (`int` or `List[int]`, *optional*, defaults to 0):
769
+ Number of diffusion steps (for each prompt) for which guidance will not be applied.
770
+ edit_cooldown_steps (`float` or `List[float]`, *optional*, defaults to `None`):
771
+ Number of diffusion steps (for each prompt) after which guidance will no longer be applied.
772
+ edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9):
773
+ Masking threshold of guidance. Threshold should be proportional to the image region that is modified.
774
+ 'edit_threshold' is defined as 'λ' of equation 12 of [LEDITS++ Paper](https://arxiv.org/abs/2301.12247).
775
+ user_mask (`torch.FloatTensor`, *optional*):
776
+ User-provided mask for even better control over the editing process. This is helpful when LEDITS++'s implicit
777
+ masks do not meet user preferences.
778
+ sem_guidance (`List[torch.Tensor]`, *optional*):
779
+ List of pre-generated guidance vectors to be applied at generation. Length of the list has to
780
+ correspond to `num_inference_steps`.
781
+ use_cross_attn_mask (`bool`, defaults to `False`):
782
+ Whether cross-attention masks are used. Cross-attention masks are always used when `use_intersect_mask`
783
+ is set to `True`. Cross-attention masks are defined as 'M^1' of equation 12 of
784
+ [LEDITS++ paper](https://arxiv.org/pdf/2311.16711.pdf).
785
+ use_intersect_mask (`bool`, defaults to `True`):
786
+ Whether the masking term is calculated as intersection of cross-attention masks and masks derived
787
+ from the noise estimate. Cross-attention mask are defined as 'M^1' and masks derived from the noise
788
+ estimate are defined as 'M^2' of equation 12 of [LEDITS++ paper](https://arxiv.org/pdf/2311.16711.pdf).
789
+ attn_store_steps (`List[int]`, *optional*):
790
+ Steps for which the attention maps are stored in the AttentionStore. Just for visualization purposes.
791
+ store_averaged_over_steps (`bool`, defaults to `True`):
792
+ Whether the attention maps for the 'attn_store_steps' are stored averaged over the diffusion steps.
793
+ If `False`, attention maps for each step are stored separately. Just for visualization purposes.
794
+ cross_attention_kwargs (`dict`, *optional*):
795
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
796
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
797
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
798
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
799
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
800
+ using zero terminal SNR.
801
+ clip_skip (`int`, *optional*):
802
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
803
+ the output of the pre-final layer will be used for computing the prompt embeddings.
804
+ callback_on_step_end (`Callable`, *optional*):
805
+ A function called at the end of each denoising step during inference. The function is called
806
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
807
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
808
+ `callback_on_step_end_tensor_inputs`.
809
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
810
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
811
+ will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in the
812
+ `._callback_tensor_inputs` attribute of your pipeline class.
813
+
814
+ Examples:
815
+
816
+ Returns:
817
+ [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] or `tuple`:
818
+ [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] if `return_dict` is True,
819
+ otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images, and the
820
+ second element is a list of `bool`s denoting whether the corresponding generated image likely represents
821
+ "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
822
+ """
823
+
824
+ if self.inversion_steps is None:
825
+ raise ValueError(
826
+ "You need to invert an input image first before calling the pipeline. The `invert` method has to be called beforehand. Edits will always be performed for the last inverted image(s)."
827
+ )
828
+
829
+ eta = self.eta
830
+ num_images_per_prompt = 1
831
+ latents = self.init_latents
832
+
833
+ zs = self.zs
834
+ self.scheduler.set_timesteps(len(self.scheduler.timesteps))
835
+
836
+ if use_intersect_mask:
837
+ use_cross_attn_mask = True
838
+
839
+ if use_cross_attn_mask:
840
+ self.smoothing = LeditsGaussianSmoothing(self.device)
841
+
842
+ if user_mask is not None:
843
+ user_mask = user_mask.to(self.device)
844
+
845
+ org_prompt = ""
846
+
847
+ # 1. Check inputs. Raise error if not correct
848
+ self.check_inputs(
849
+ negative_prompt,
850
+ editing_prompt_embeds,
851
+ negative_prompt_embeds,
852
+ callback_on_step_end_tensor_inputs,
853
+ )
854
+
855
+ self._guidance_rescale = guidance_rescale
856
+ self._clip_skip = clip_skip
857
+ self._cross_attention_kwargs = cross_attention_kwargs
858
+
859
+ # 2. Define call parameters
860
+ batch_size = self.batch_size
861
+
862
+ if editing_prompt:
863
+ enable_edit_guidance = True
864
+ if isinstance(editing_prompt, str):
865
+ editing_prompt = [editing_prompt]
866
+ self.enabled_editing_prompts = len(editing_prompt)
867
+ elif editing_prompt_embeds is not None:
868
+ enable_edit_guidance = True
869
+ self.enabled_editing_prompts = editing_prompt_embeds.shape[0]
870
+ else:
871
+ self.enabled_editing_prompts = 0
872
+ enable_edit_guidance = False
873
+
874
+ # 3. Encode input prompt
875
+ lora_scale = (
876
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
877
+ )
878
+
879
+ edit_concepts, uncond_embeddings, num_edit_tokens = self.encode_prompt(
880
+ editing_prompt=editing_prompt,
881
+ device=self.device,
882
+ num_images_per_prompt=num_images_per_prompt,
883
+ enable_edit_guidance=enable_edit_guidance,
884
+ negative_prompt=negative_prompt,
885
+ editing_prompt_embeds=editing_prompt_embeds,
886
+ negative_prompt_embeds=negative_prompt_embeds,
887
+ lora_scale=lora_scale,
888
+ clip_skip=self.clip_skip,
889
+ )
890
+
891
+ # For classifier free guidance, we need to do two forward passes.
892
+ # Here we concatenate the unconditional and text embeddings into a single batch
893
+ # to avoid doing two forward passes
894
+ if enable_edit_guidance:
895
+ text_embeddings = torch.cat([uncond_embeddings, edit_concepts])
896
+ self.text_cross_attention_maps = [editing_prompt] if isinstance(editing_prompt, str) else editing_prompt
897
+ else:
898
+ text_embeddings = torch.cat([uncond_embeddings])
899
+
900
+ # 4. Prepare timesteps
901
+ # self.scheduler.set_timesteps(num_inference_steps, device=self.device)
902
+ timesteps = self.inversion_steps
903
+ t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0] :])}
904
+
905
+ if use_cross_attn_mask:
906
+ self.attention_store = LeditsAttentionStore(
907
+ average=store_averaged_over_steps,
908
+ batch_size=batch_size,
909
+ max_size=(latents.shape[-2] / 4.0) * (latents.shape[-1] / 4.0),
910
+ max_resolution=None,
911
+ )
912
+ self.prepare_unet(self.attention_store, PnP=False)
913
+ resolution = latents.shape[-2:]
914
+ att_res = (int(resolution[0] / 4), int(resolution[1] / 4))
915
+
916
+ # 5. Prepare latent variables
917
+ num_channels_latents = self.unet.config.in_channels
918
+ latents = self.prepare_latents(
919
+ batch_size * num_images_per_prompt,
920
+ num_channels_latents,
921
+ None,
922
+ None,
923
+ text_embeddings.dtype,
924
+ self.device,
925
+ latents,
926
+ )
927
+
928
+ # 6. Prepare extra step kwargs.
929
+ extra_step_kwargs = self.prepare_extra_step_kwargs(eta)
930
+
931
+ self.sem_guidance = None
932
+ self.activation_mask = None
933
+
934
+ # 7. Denoising loop
935
+ num_warmup_steps = 0
936
+ with self.progress_bar(total=len(timesteps)) as progress_bar:
937
+ for i, t in enumerate(timesteps):
938
+ # expand the latents if we are doing classifier free guidance
939
+
940
+ if enable_edit_guidance:
941
+ latent_model_input = torch.cat([latents] * (1 + self.enabled_editing_prompts))
942
+ else:
943
+ latent_model_input = latents
944
+
945
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
946
+
947
+ text_embed_input = text_embeddings
948
+
949
+ # predict the noise residual
950
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embed_input).sample
951
+
952
+ noise_pred_out = noise_pred.chunk(1 + self.enabled_editing_prompts) # [b,4, 64, 64]
953
+ noise_pred_uncond = noise_pred_out[0]
954
+ noise_pred_edit_concepts = noise_pred_out[1:]
955
+
956
+ noise_guidance_edit = torch.zeros(
957
+ noise_pred_uncond.shape,
958
+ device=self.device,
959
+ dtype=noise_pred_uncond.dtype,
960
+ )
961
+
962
+ if sem_guidance is not None and len(sem_guidance) > i:
963
+ noise_guidance_edit += sem_guidance[i].to(self.device)
964
+
965
+ elif enable_edit_guidance:
966
+ if self.activation_mask is None:
967
+ self.activation_mask = torch.zeros(
968
+ (len(timesteps), len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape)
969
+ )
970
+
971
+ if self.sem_guidance is None:
972
+ self.sem_guidance = torch.zeros((len(timesteps), *noise_pred_uncond.shape))
973
+
974
+ for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts):
975
+ if isinstance(edit_warmup_steps, list):
976
+ edit_warmup_steps_c = edit_warmup_steps[c]
977
+ else:
978
+ edit_warmup_steps_c = edit_warmup_steps
979
+ if i < edit_warmup_steps_c:
980
+ continue
981
+
982
+ if isinstance(edit_guidance_scale, list):
983
+ edit_guidance_scale_c = edit_guidance_scale[c]
984
+ else:
985
+ edit_guidance_scale_c = edit_guidance_scale
986
+
987
+ if isinstance(edit_threshold, list):
988
+ edit_threshold_c = edit_threshold[c]
989
+ else:
990
+ edit_threshold_c = edit_threshold
991
+ if isinstance(reverse_editing_direction, list):
992
+ reverse_editing_direction_c = reverse_editing_direction[c]
993
+ else:
994
+ reverse_editing_direction_c = reverse_editing_direction
995
+
996
+ if isinstance(edit_cooldown_steps, list):
997
+ edit_cooldown_steps_c = edit_cooldown_steps[c]
998
+ elif edit_cooldown_steps is None:
999
+ edit_cooldown_steps_c = i + 1
1000
+ else:
1001
+ edit_cooldown_steps_c = edit_cooldown_steps
1002
+
1003
+ if i >= edit_cooldown_steps_c:
1004
+ continue
1005
+
1006
+ noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond
1007
+
1008
+ if reverse_editing_direction_c:
1009
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1
1010
+
1011
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c
1012
+
1013
+ if user_mask is not None:
1014
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * user_mask
1015
+
1016
+ if use_cross_attn_mask:
1017
+ out = self.attention_store.aggregate_attention(
1018
+ attention_maps=self.attention_store.step_store,
1019
+ prompts=self.text_cross_attention_maps,
1020
+ res=att_res,
1021
+ from_where=["up", "down"],
1022
+ is_cross=True,
1023
+ select=self.text_cross_attention_maps.index(editing_prompt[c]),
1024
+ )
1025
+ attn_map = out[:, :, :, 1 : 1 + num_edit_tokens[c]] # 0 -> startoftext
1026
+
1027
+ # sum over all tokens of the edit concept
1028
+ if attn_map.shape[3] != num_edit_tokens[c]:
1029
+ raise ValueError(
1030
+ f"Incorrect shape of attention_map. Expected size {num_edit_tokens[c]}, but found {attn_map.shape[3]}!"
1031
+ )
1032
+
1033
+ attn_map = torch.sum(attn_map, dim=3)
1034
+
1035
+ # gaussian_smoothing
1036
+ attn_map = F.pad(attn_map.unsqueeze(1), (1, 1, 1, 1), mode="reflect")
1037
+ attn_map = self.smoothing(attn_map).squeeze(1)
1038
+
1039
+ # torch.quantile function expects float32
1040
+ if attn_map.dtype == torch.float32:
1041
+ tmp = torch.quantile(attn_map.flatten(start_dim=1), edit_threshold_c, dim=1)
1042
+ else:
1043
+ tmp = torch.quantile(
1044
+ attn_map.flatten(start_dim=1).to(torch.float32), edit_threshold_c, dim=1
1045
+ ).to(attn_map.dtype)
1046
+ attn_mask = torch.where(
1047
+ attn_map >= tmp.unsqueeze(1).unsqueeze(1).repeat(1, *att_res), 1.0, 0.0
1048
+ )
1049
+
1050
+ # resolution must match latent space dimension
1051
+ attn_mask = F.interpolate(
1052
+ attn_mask.unsqueeze(1),
1053
+ noise_guidance_edit_tmp.shape[-2:], # 64,64
1054
+ ).repeat(1, 4, 1, 1)
1055
+ self.activation_mask[i, c] = attn_mask.detach().cpu()
1056
+ if not use_intersect_mask:
1057
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * attn_mask
1058
+
1059
+ if use_intersect_mask:
1060
+ if t <= 800:
1061
+ noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp)
1062
+ noise_guidance_edit_tmp_quantile = torch.sum(
1063
+ noise_guidance_edit_tmp_quantile, dim=1, keepdim=True
1064
+ )
1065
+ noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(
1066
+ 1, self.unet.config.in_channels, 1, 1
1067
+ )
1068
+
1069
+ # torch.quantile function expects float32
1070
+ if noise_guidance_edit_tmp_quantile.dtype == torch.float32:
1071
+ tmp = torch.quantile(
1072
+ noise_guidance_edit_tmp_quantile.flatten(start_dim=2),
1073
+ edit_threshold_c,
1074
+ dim=2,
1075
+ keepdim=False,
1076
+ )
1077
+ else:
1078
+ tmp = torch.quantile(
1079
+ noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32),
1080
+ edit_threshold_c,
1081
+ dim=2,
1082
+ keepdim=False,
1083
+ ).to(noise_guidance_edit_tmp_quantile.dtype)
1084
+
1085
+ intersect_mask = (
1086
+ torch.where(
1087
+ noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
1088
+ torch.ones_like(noise_guidance_edit_tmp),
1089
+ torch.zeros_like(noise_guidance_edit_tmp),
1090
+ )
1091
+ * attn_mask
1092
+ )
1093
+
1094
+ self.activation_mask[i, c] = intersect_mask.detach().cpu()
1095
+
1096
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * intersect_mask
1097
+
1098
+ else:
1099
+ # print(f"only attention mask for step {i}")
1100
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * attn_mask
1101
+
1102
+ elif not use_cross_attn_mask:
1103
+ # calculate quantile
1104
+ noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp)
1105
+ noise_guidance_edit_tmp_quantile = torch.sum(
1106
+ noise_guidance_edit_tmp_quantile, dim=1, keepdim=True
1107
+ )
1108
+ noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1, 4, 1, 1)
1109
+
1110
+ # torch.quantile function expects float32
1111
+ if noise_guidance_edit_tmp_quantile.dtype == torch.float32:
1112
+ tmp = torch.quantile(
1113
+ noise_guidance_edit_tmp_quantile.flatten(start_dim=2),
1114
+ edit_threshold_c,
1115
+ dim=2,
1116
+ keepdim=False,
1117
+ )
1118
+ else:
1119
+ tmp = torch.quantile(
1120
+ noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32),
1121
+ edit_threshold_c,
1122
+ dim=2,
1123
+ keepdim=False,
1124
+ ).to(noise_guidance_edit_tmp_quantile.dtype)
1125
+
1126
+ self.activation_mask[i, c] = (
1127
+ torch.where(
1128
+ noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
1129
+ torch.ones_like(noise_guidance_edit_tmp),
1130
+ torch.zeros_like(noise_guidance_edit_tmp),
1131
+ )
1132
+ .detach()
1133
+ .cpu()
1134
+ )
1135
+
1136
+ noise_guidance_edit_tmp = torch.where(
1137
+ noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
1138
+ noise_guidance_edit_tmp,
1139
+ torch.zeros_like(noise_guidance_edit_tmp),
1140
+ )
1141
+
1142
+ noise_guidance_edit += noise_guidance_edit_tmp
1143
+
1144
+ self.sem_guidance[i] = noise_guidance_edit.detach().cpu()
1145
+
1146
+ noise_pred = noise_pred_uncond + noise_guidance_edit
1147
+
1148
+ if enable_edit_guidance and self.guidance_rescale > 0.0:
1149
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1150
+ noise_pred = rescale_noise_cfg(
1151
+ noise_pred,
1152
+ noise_pred_edit_concepts.mean(dim=0, keepdim=False),
1153
+ guidance_rescale=self.guidance_rescale,
1154
+ )
1155
+
1156
+ idx = t_to_idx[int(t)]
1157
+ latents = self.scheduler.step(
1158
+ noise_pred, t, latents, variance_noise=zs[idx], **extra_step_kwargs
1159
+ ).prev_sample
1160
+
1161
+ # step callback
1162
+ if use_cross_attn_mask:
1163
+ store_step = i in attn_store_steps
1164
+ self.attention_store.between_steps(store_step)
1165
+
1166
+ if callback_on_step_end is not None:
1167
+ callback_kwargs = {}
1168
+ for k in callback_on_step_end_tensor_inputs:
1169
+ callback_kwargs[k] = locals()[k]
1170
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1171
+
1172
+ latents = callback_outputs.pop("latents", latents)
1173
+ # prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1174
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1175
+
1176
+ # call the callback, if provided
1177
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1178
+ progress_bar.update()
1179
+
1180
+ # 8. Post-processing
1181
+ if not output_type == "latent":
1182
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
1183
+ 0
1184
+ ]
1185
+ image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype)
1186
+ else:
1187
+ image = latents
1188
+ has_nsfw_concept = None
1189
+
1190
+ if has_nsfw_concept is None:
1191
+ do_denormalize = [True] * image.shape[0]
1192
+ else:
1193
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1194
+
1195
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1196
+
1197
+ # Offload all models
1198
+ self.maybe_free_model_hooks()
1199
+
1200
+ if not return_dict:
1201
+ return (image, has_nsfw_concept)
1202
+
1203
+ return LEditsPPDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
1204
+
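
The `Examples:` block in the docstring above is left empty in this diff. As orientation only, here is a minimal sketch of the two-stage LEDITS++ workflow (invert first, then edit); the checkpoint id, input path, and editing prompt are placeholders rather than values taken from the library, and the explicit scheduler swap assumes the dispatch rules of `compute_noise` at the end of this file.

import torch
from diffusers import DPMSolverMultistepScheduler, LEditsPPPipelineStableDiffusion
from diffusers.utils import load_image

pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # placeholder checkpoint
).to("cuda")
# `compute_noise` (end of this file) only supports DDIM or second-order sde-dpmsolver++ schedulers.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config, algorithm_type="sde-dpmsolver++", solver_order=2
)

image = load_image("path/to/input.png")  # placeholder input image

# Stage 1: inversion stores `init_latents` and the noise maps `zs` on the pipeline.
_ = pipe.invert(image=image, num_inversion_steps=50, skip=0.15)

# Callbacks must accept (pipeline, step, timestep, callback_kwargs) and return a dict.
def on_step_end(pipeline, step, timestep, callback_kwargs):
    return callback_kwargs

# Stage 2: edit by re-running the diffusion with semantic guidance.
result = pipe(
    editing_prompt=["cherry blossoms"],  # placeholder edit concept
    reverse_editing_direction=[False],
    edit_guidance_scale=[7.5],
    edit_threshold=[0.9],
    callback_on_step_end=on_step_end,
)
edited_image = result.images[0]
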
1205
+ @torch.no_grad()
1206
+ def invert(
1207
+ self,
1208
+ image: PipelineImageInput,
1209
+ source_prompt: str = "",
1210
+ source_guidance_scale: float = 3.5,
1211
+ num_inversion_steps: int = 30,
1212
+ skip: float = 0.15,
1213
+ generator: Optional[torch.Generator] = None,
1214
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1215
+ clip_skip: Optional[int] = None,
1216
+ height: Optional[int] = None,
1217
+ width: Optional[int] = None,
1218
+ resize_mode: Optional[str] = "default",
1219
+ crops_coords: Optional[Tuple[int, int, int, int]] = None,
1220
+ ):
1221
+ r"""
1222
+ The function of the pipeline for image inversion, as described in the [LEDITS++ Paper](https://arxiv.org/abs/2311.16711).
1223
+ If the scheduler is set to [`~schedulers.DDIMScheduler`], the inversion proposed by [edit-friendly DDPM](https://arxiv.org/abs/2304.06140)
1224
+ will be performed instead.
1225
+
1226
+ Args:
1227
+ image (`PipelineImageInput`):
1228
+ Input for the image(s) that are to be edited. Multiple input images have to share the same aspect
1229
+ ratio.
1230
+ source_prompt (`str`, defaults to `""`):
1231
+ Prompt describing the input image that will be used for guidance during inversion. Guidance is disabled
1232
+ if the `source_prompt` is `""`.
1233
+ source_guidance_scale (`float`, defaults to `3.5`):
1234
+ Strength of guidance during inversion.
1235
+ num_inversion_steps (`int`, defaults to `30`):
1236
+ Number of total performed inversion steps after discarding the initial `skip` steps.
1237
+ skip (`float`, defaults to `0.15`):
1238
+ Portion of initial steps that will be ignored for inversion and subsequent generation. Lower values
1239
+ will lead to stronger changes to the input image. `skip` has to be between `0` and `1`.
1240
+ generator (`torch.Generator`, *optional*):
1241
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
1242
+ inversion deterministic.
1243
+ cross_attention_kwargs (`dict`, *optional*):
1244
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
1245
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1246
+ clip_skip (`int`, *optional*):
1247
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1248
+ the output of the pre-final layer will be used for computing the prompt embeddings.
1249
+ height (`int`, *optional*, defaults to `None`):
1250
+ The height of the preprocessed image. If `None`, will use `get_default_height_width()` to get the default height.
1251
+ width (`int`, *optional*, defaults to `None`):
1252
+ The width of the preprocessed image. If `None`, will use `get_default_height_width()` to get the default width.
1253
+ resize_mode (`str`, *optional*, defaults to `default`):
1254
+ The resize mode, can be one of `default`, `fill` or `crop`. If `default`, will resize the image to fit
1255
+ within the specified width and height, and it may not maintain the original aspect ratio.
1256
+ If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
1257
+ within the dimensions, filling the empty space with data from the image.
1258
+ If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
1259
+ within the dimensions, cropping the excess.
1260
+ Note that the `fill` and `crop` resize modes are only supported for PIL image input.
1261
+ crops_coords (`Tuple[int, int, int, int]`, *optional*, defaults to `None`):
1262
+ The crop coordinates for each image in the batch. If `None`, will not crop the image.
1263
+
1264
+ Returns:
1265
+ [`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]:
1266
+ Output will contain the resized input image(s) and respective VAE reconstruction(s).
1267
+ """
1268
+ # Reset attn processor, we do not want to store attn maps during inversion
1269
+ self.unet.set_attn_processor(AttnProcessor())
1270
+
1271
+ self.eta = 1.0
1272
+
1273
+ self.scheduler.config.timestep_spacing = "leading"
1274
+ self.scheduler.set_timesteps(int(num_inversion_steps * (1 + skip)))
1275
+ self.inversion_steps = self.scheduler.timesteps[-num_inversion_steps:]
1276
+ timesteps = self.inversion_steps
1277
+
1278
+ # 1. encode image
1279
+ x0, resized = self.encode_image(
1280
+ image,
1281
+ dtype=self.text_encoder.dtype,
1282
+ height=height,
1283
+ width=width,
1284
+ resize_mode=resize_mode,
1285
+ crops_coords=crops_coords,
1286
+ )
1287
+ self.batch_size = x0.shape[0]
1288
+
1289
+ # autoencoder reconstruction
1290
+ image_rec = self.vae.decode(x0 / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
1291
+ image_rec = self.image_processor.postprocess(image_rec, output_type="pil")
1292
+
1293
+ # 2. get embeddings
1294
+ do_classifier_free_guidance = source_guidance_scale > 1.0
1295
+
1296
+ lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
1297
+
1298
+ uncond_embedding, text_embeddings, _ = self.encode_prompt(
1299
+ num_images_per_prompt=1,
1300
+ device=self.device,
1301
+ negative_prompt=None,
1302
+ enable_edit_guidance=do_classifier_free_guidance,
1303
+ editing_prompt=source_prompt,
1304
+ lora_scale=lora_scale,
1305
+ clip_skip=clip_skip,
1306
+ )
1307
+
1308
+ # 3. find zs and xts
1309
+ variance_noise_shape = (num_inversion_steps, *x0.shape)
1310
+
1311
+ # intermediate latents
1312
+ t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
1313
+ xts = torch.zeros(size=variance_noise_shape, device=self.device, dtype=uncond_embedding.dtype)
1314
+
1315
+ for t in reversed(timesteps):
1316
+ idx = num_inversion_steps - t_to_idx[int(t)] - 1
1317
+ noise = randn_tensor(shape=x0.shape, generator=generator, device=self.device, dtype=x0.dtype)
1318
+ xts[idx] = self.scheduler.add_noise(x0, noise, torch.Tensor([t]))
1319
+ xts = torch.cat([x0.unsqueeze(0), xts], dim=0)
1320
+
1321
+ self.scheduler.set_timesteps(len(self.scheduler.timesteps))
1322
+ # noise maps
1323
+ zs = torch.zeros(size=variance_noise_shape, device=self.device, dtype=uncond_embedding.dtype)
1324
+
1325
+ with self.progress_bar(total=len(timesteps)) as progress_bar:
1326
+ for t in timesteps:
1327
+ idx = num_inversion_steps - t_to_idx[int(t)] - 1
1328
+ # 1. predict noise residual
1329
+ xt = xts[idx + 1]
1330
+
1331
+ noise_pred = self.unet(xt, timestep=t, encoder_hidden_states=uncond_embedding).sample
1332
+
1333
+ if not source_prompt == "":
1334
+ noise_pred_cond = self.unet(xt, timestep=t, encoder_hidden_states=text_embeddings).sample
1335
+ noise_pred = noise_pred + source_guidance_scale * (noise_pred_cond - noise_pred)
1336
+
1337
+ xtm1 = xts[idx]
1338
+ z, xtm1_corrected = compute_noise(self.scheduler, xtm1, xt, t, noise_pred, self.eta)
1339
+ zs[idx] = z
1340
+
1341
+ # correction to avoid error accumulation
1342
+ xts[idx] = xtm1_corrected
1343
+
1344
+ progress_bar.update()
1345
+
1346
+ self.init_latents = xts[-1].expand(self.batch_size, -1, -1, -1)
1347
+ zs = zs.flip(0)
1348
+ self.zs = zs
1349
+
1350
+ return LEditsPPInversionPipelineOutput(images=resized, vae_reconstruction_images=image_rec)
1351
+
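
A small numeric illustration of how `skip` and `num_inversion_steps` interact in the two scheduler lines near the top of `invert` (pure arithmetic, no library calls; the values are the function defaults):

num_inversion_steps = 30  # default of `invert`
skip = 0.15               # default of `invert`

total_timesteps = int(num_inversion_steps * (1 + skip))  # 34 scheduler timesteps
dropped = total_timesteps - num_inversion_steps          # 4 timesteps discarded by the [-num_inversion_steps:] slice
print(total_timesteps, dropped)                          # 34 4

Since `scheduler.timesteps` runs from high noise to low noise, the discarded timesteps are the noisiest ones, which is why larger `skip` values keep the edited result closer to the input image.
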
1352
+ @torch.no_grad()
1353
+ def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="default", crops_coords=None):
1354
+ image = self.image_processor.preprocess(
1355
+ image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
1356
+ )
1357
+ resized = self.image_processor.postprocess(image=image, output_type="pil")
1358
+
1359
+ if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
1360
+ logger.warning(
1361
+ "Your input images far exceed the default resolution of the underlying diffusion model. "
1362
+ "The output images may contain severe artifacts! "
1363
+ "Consider down-sampling the input using the `height` and `width` parameters"
1364
+ )
1365
+ image = image.to(dtype)
1366
+
1367
+ x0 = self.vae.encode(image.to(self.device)).latent_dist.mode()
1368
+ x0 = x0.to(dtype)
1369
+ x0 = self.vae.config.scaling_factor * x0
1370
+ return x0, resized
1371
+
1372
+
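
The `height`, `width`, `resize_mode`, and `crops_coords` arguments documented for `invert` are forwarded unchanged to the image processor in `encode_image` above. A minimal sketch of that preprocessing step in isolation, assuming the standard `VaeImageProcessor` used by the Stable Diffusion pipelines and the keyword arguments shown in the call above:

from PIL import Image
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8)
img = Image.new("RGB", (600, 400))  # placeholder input

# `fill` and `crop` are only supported for PIL inputs (see the `invert` docstring).
tensor = processor.preprocess(img, height=512, width=512, resize_mode="fill")
print(tensor.shape)  # torch.Size([1, 3, 512, 512]), values normalized to [-1, 1]
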
1373
+ def compute_noise_ddim(scheduler, prev_latents, latents, timestep, noise_pred, eta):
1374
+ # 1. get previous step value (=t-1)
1375
+ prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
1376
+
1377
+ # 2. compute alphas, betas
1378
+ alpha_prod_t = scheduler.alphas_cumprod[timestep]
1379
+ alpha_prod_t_prev = (
1380
+ scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
1381
+ )
1382
+
1383
+ beta_prod_t = 1 - alpha_prod_t
1384
+
1385
+ # 3. compute predicted original sample from predicted noise also called
1386
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
1387
+ pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
1388
+
1389
+ # 4. Clip "predicted x_0"
1390
+ if scheduler.config.clip_sample:
1391
+ pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
1392
+
1393
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
1394
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
1395
+ variance = scheduler._get_variance(timestep, prev_timestep)
1396
+ std_dev_t = eta * variance ** (0.5)
1397
+
1398
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
1399
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred
1400
+
1401
+ # modified so that the updated xtm1 is returned as well (to avoid error accumulation)
1402
+ mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
1403
+ if variance > 0.0:
1404
+ noise = (prev_latents - mu_xt) / (variance ** (0.5) * eta)
1405
+ else:
1406
+ noise = torch.tensor([0.0]).to(latents.device)
1407
+
1408
+ return noise, mu_xt + (eta * variance**0.5) * noise
1409
+
1410
+
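
Both `compute_noise` helpers (the DDIM one above and the SDE-DPM-Solver++ one below) recover the variance noise by inverting the scheduler's posterior update and then return `mu + sigma * noise`, which reproduces the stored `x_{t-1}` exactly; that is the "correction to avoid error accumulation" applied in `invert`. A minimal, scheduler-free sketch of the identity:

import torch

torch.manual_seed(0)
mu = torch.randn(1, 4, 64, 64)                     # posterior mean predicted from x_t
sigma = torch.tensor(0.1)                          # posterior std (eta * sigma_t in the DDIM case)
prev_latents = mu + sigma * torch.randn_like(mu)   # the x_{t-1} stored during inversion

noise = (prev_latents - mu) / sigma                # recovered variance noise
rebuilt = mu + sigma * noise                       # what the sampler produces when fed this noise

print(torch.allclose(rebuilt, prev_latents))       # True
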
1411
+ def compute_noise_sde_dpm_pp_2nd(scheduler, prev_latents, latents, timestep, noise_pred, eta):
1412
+ def first_order_update(model_output, sample): # timestep, prev_timestep, sample):
1413
+ sigma_t, sigma_s = scheduler.sigmas[scheduler.step_index + 1], scheduler.sigmas[scheduler.step_index]
1414
+ alpha_t, sigma_t = scheduler._sigma_to_alpha_sigma_t(sigma_t)
1415
+ alpha_s, sigma_s = scheduler._sigma_to_alpha_sigma_t(sigma_s)
1416
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
1417
+ lambda_s = torch.log(alpha_s) - torch.log(sigma_s)
1418
+
1419
+ h = lambda_t - lambda_s
1420
+
1421
+ mu_xt = (sigma_t / sigma_s * torch.exp(-h)) * sample + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
1422
+
1423
+ mu_xt = scheduler.dpm_solver_first_order_update(
1424
+ model_output=model_output, sample=sample, noise=torch.zeros_like(sample)
1425
+ )
1426
+
1427
+ sigma = sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h))
1428
+ if sigma > 0.0:
1429
+ noise = (prev_latents - mu_xt) / sigma
1430
+ else:
1431
+ noise = torch.tensor([0.0]).to(sample.device)
1432
+
1433
+ prev_sample = mu_xt + sigma * noise
1434
+ return noise, prev_sample
1435
+
1436
+ def second_order_update(model_output_list, sample): # timestep_list, prev_timestep, sample):
1437
+ sigma_t, sigma_s0, sigma_s1 = (
1438
+ scheduler.sigmas[scheduler.step_index + 1],
1439
+ scheduler.sigmas[scheduler.step_index],
1440
+ scheduler.sigmas[scheduler.step_index - 1],
1441
+ )
1442
+
1443
+ alpha_t, sigma_t = scheduler._sigma_to_alpha_sigma_t(sigma_t)
1444
+ alpha_s0, sigma_s0 = scheduler._sigma_to_alpha_sigma_t(sigma_s0)
1445
+ alpha_s1, sigma_s1 = scheduler._sigma_to_alpha_sigma_t(sigma_s1)
1446
+
1447
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
1448
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
1449
+ lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1)
1450
+
1451
+ m0, m1 = model_output_list[-1], model_output_list[-2]
1452
+
1453
+ h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
1454
+ r0 = h_0 / h
1455
+ D0, D1 = m0, (1.0 / r0) * (m0 - m1)
1456
+
1457
+ mu_xt = (
1458
+ (sigma_t / sigma_s0 * torch.exp(-h)) * sample
1459
+ + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
1460
+ + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
1461
+ )
1462
+
1463
+ sigma = sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h))
1464
+ if sigma > 0.0:
1465
+ noise = (prev_latents - mu_xt) / sigma
1466
+ else:
1467
+ noise = torch.tensor([0.0]).to(sample.device)
1468
+
1469
+ prev_sample = mu_xt + sigma * noise
1470
+
1471
+ return noise, prev_sample
1472
+
1473
+ if scheduler.step_index is None:
1474
+ scheduler._init_step_index(timestep)
1475
+
1476
+ model_output = scheduler.convert_model_output(model_output=noise_pred, sample=latents)
1477
+ for i in range(scheduler.config.solver_order - 1):
1478
+ scheduler.model_outputs[i] = scheduler.model_outputs[i + 1]
1479
+ scheduler.model_outputs[-1] = model_output
1480
+
1481
+ if scheduler.lower_order_nums < 1:
1482
+ noise, prev_sample = first_order_update(model_output, latents)
1483
+ else:
1484
+ noise, prev_sample = second_order_update(scheduler.model_outputs, latents)
1485
+
1486
+ if scheduler.lower_order_nums < scheduler.config.solver_order:
1487
+ scheduler.lower_order_nums += 1
1488
+
1489
+ # upon completion increase step index by one
1490
+ scheduler._step_index += 1
1491
+
1492
+ return noise, prev_sample
1493
+
1494
+
1495
+ def compute_noise(scheduler, *args):
1496
+ if isinstance(scheduler, DDIMScheduler):
1497
+ return compute_noise_ddim(scheduler, *args)
1498
+ elif (
1499
+ isinstance(scheduler, DPMSolverMultistepScheduler)
1500
+ and scheduler.config.algorithm_type == "sde-dpmsolver++"
1501
+ and scheduler.config.solver_order == 2
1502
+ ):
1503
+ return compute_noise_sde_dpm_pp_2nd(scheduler, *args)
1504
+ else:
1505
+ raise NotImplementedError
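
`compute_noise` dispatches on the scheduler type and configuration and raises `NotImplementedError` for anything else. A hedged sketch of selecting a compatible scheduler before calling `invert`; the checkpoint id is a placeholder, and the `from_config` overrides are the scheduler's standard configuration keys:

from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, LEditsPPPipelineStableDiffusion

pipe = LEditsPPPipelineStableDiffusion.from_pretrained("runwayml/stable-diffusion-v1-5")  # placeholder checkpoint

# Option 1: DDIM branch (edit-friendly DDPM inversion, see the `invert` docstring)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

# Option 2: second-order SDE-DPM-Solver++ branch
pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config, algorithm_type="sde-dpmsolver++", solver_order=2
)
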