diffusers 0.26.3__py3-none-any.whl → 0.27.0__py3-none-any.whl

Files changed (299)
  1. diffusers/__init__.py +20 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +7 -3
  7. diffusers/dependency_versions_check.py +1 -1
  8. diffusers/dependency_versions_table.py +2 -2
  9. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  10. diffusers/image_processor.py +110 -4
  11. diffusers/loaders/autoencoder.py +7 -8
  12. diffusers/loaders/controlnet.py +17 -8
  13. diffusers/loaders/ip_adapter.py +86 -23
  14. diffusers/loaders/lora.py +105 -310
  15. diffusers/loaders/lora_conversion_utils.py +1 -1
  16. diffusers/loaders/peft.py +1 -1
  17. diffusers/loaders/single_file.py +51 -12
  18. diffusers/loaders/single_file_utils.py +274 -49
  19. diffusers/loaders/textual_inversion.py +23 -4
  20. diffusers/loaders/unet.py +195 -41
  21. diffusers/loaders/utils.py +1 -1
  22. diffusers/models/__init__.py +3 -1
  23. diffusers/models/activations.py +9 -9
  24. diffusers/models/attention.py +26 -36
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +171 -114
  27. diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
  28. diffusers/models/autoencoders/autoencoder_kl.py +3 -1
  29. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  30. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  31. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  32. diffusers/models/autoencoders/vae.py +1 -1
  33. diffusers/models/controlnet.py +1 -1
  34. diffusers/models/controlnet_flax.py +1 -1
  35. diffusers/models/downsampling.py +8 -12
  36. diffusers/models/dual_transformer_2d.py +1 -1
  37. diffusers/models/embeddings.py +3 -4
  38. diffusers/models/embeddings_flax.py +1 -1
  39. diffusers/models/lora.py +33 -10
  40. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  41. diffusers/models/modeling_flax_utils.py +1 -1
  42. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  43. diffusers/models/modeling_utils.py +4 -6
  44. diffusers/models/normalization.py +1 -1
  45. diffusers/models/resnet.py +31 -58
  46. diffusers/models/resnet_flax.py +1 -1
  47. diffusers/models/t5_film_transformer.py +1 -1
  48. diffusers/models/transformer_2d.py +1 -1
  49. diffusers/models/transformer_temporal.py +1 -1
  50. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  51. diffusers/models/transformers/t5_film_transformer.py +1 -1
  52. diffusers/models/transformers/transformer_2d.py +29 -31
  53. diffusers/models/transformers/transformer_temporal.py +1 -1
  54. diffusers/models/unet_1d.py +1 -1
  55. diffusers/models/unet_1d_blocks.py +1 -1
  56. diffusers/models/unet_2d.py +1 -1
  57. diffusers/models/unet_2d_blocks.py +1 -1
  58. diffusers/models/unet_2d_condition.py +1 -1
  59. diffusers/models/unets/__init__.py +1 -0
  60. diffusers/models/unets/unet_1d.py +1 -1
  61. diffusers/models/unets/unet_1d_blocks.py +1 -1
  62. diffusers/models/unets/unet_2d.py +4 -4
  63. diffusers/models/unets/unet_2d_blocks.py +238 -98
  64. diffusers/models/unets/unet_2d_blocks_flax.py +1 -1
  65. diffusers/models/unets/unet_2d_condition.py +420 -323
  66. diffusers/models/unets/unet_2d_condition_flax.py +21 -12
  67. diffusers/models/unets/unet_3d_blocks.py +50 -40
  68. diffusers/models/unets/unet_3d_condition.py +47 -8
  69. diffusers/models/unets/unet_i2vgen_xl.py +75 -30
  70. diffusers/models/unets/unet_kandinsky3.py +1 -1
  71. diffusers/models/unets/unet_motion_model.py +48 -8
  72. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  73. diffusers/models/unets/unet_stable_cascade.py +610 -0
  74. diffusers/models/unets/uvit_2d.py +1 -1
  75. diffusers/models/upsampling.py +10 -16
  76. diffusers/models/vae_flax.py +1 -1
  77. diffusers/models/vq_model.py +1 -1
  78. diffusers/optimization.py +1 -1
  79. diffusers/pipelines/__init__.py +26 -0
  80. diffusers/pipelines/amused/pipeline_amused.py +1 -1
  81. diffusers/pipelines/amused/pipeline_amused_img2img.py +1 -1
  82. diffusers/pipelines/amused/pipeline_amused_inpaint.py +1 -1
  83. diffusers/pipelines/animatediff/pipeline_animatediff.py +162 -417
  84. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +165 -137
  85. diffusers/pipelines/animatediff/pipeline_output.py +7 -6
  86. diffusers/pipelines/audioldm/pipeline_audioldm.py +3 -19
  87. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  88. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +3 -3
  89. diffusers/pipelines/auto_pipeline.py +7 -16
  90. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  93. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  94. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  95. diffusers/pipelines/controlnet/pipeline_controlnet.py +90 -90
  96. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  97. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +98 -90
  98. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +92 -90
  99. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +145 -70
  100. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +126 -89
  101. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +108 -96
  102. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  103. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -1
  104. diffusers/pipelines/ddim/pipeline_ddim.py +1 -1
  105. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  106. diffusers/pipelines/deepfloyd_if/pipeline_if.py +4 -4
  107. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +4 -4
  108. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +5 -5
  109. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +4 -4
  110. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +5 -5
  111. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +5 -5
  112. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +10 -120
  113. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -91
  114. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  115. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +1 -1
  116. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  117. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +1 -1
  118. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  119. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  120. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  121. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  122. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  123. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  124. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +5 -4
  125. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +5 -4
  126. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +7 -22
  127. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -39
  128. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +5 -5
  129. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  130. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -22
  131. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  132. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  133. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -2
  134. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  135. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  136. diffusers/pipelines/free_init_utils.py +184 -0
  137. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +22 -104
  138. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +1 -1
  139. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  140. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +1 -1
  141. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +2 -2
  142. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -1
  143. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -1
  145. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +1 -1
  146. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
  147. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  148. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +2 -2
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +104 -93
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +112 -74
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/ledits_pp/__init__.py +55 -0
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +1505 -0
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +1797 -0
  155. diffusers/pipelines/ledits_pp/pipeline_output.py +43 -0
  156. diffusers/pipelines/musicldm/pipeline_musicldm.py +3 -19
  157. diffusers/pipelines/onnx_utils.py +1 -1
  158. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  159. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +3 -3
  160. diffusers/pipelines/pia/pipeline_pia.py +168 -327
  161. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  162. diffusers/pipelines/pipeline_loading_utils.py +508 -0
  163. diffusers/pipelines/pipeline_utils.py +188 -534
  164. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +56 -10
  165. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +3 -3
  166. diffusers/pipelines/shap_e/camera.py +1 -1
  167. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  168. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  169. diffusers/pipelines/shap_e/renderer.py +1 -1
  170. diffusers/pipelines/stable_cascade/__init__.py +50 -0
  171. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +482 -0
  172. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +311 -0
  173. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +638 -0
  174. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  175. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +4 -1
  176. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  177. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +2 -2
  178. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  179. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +1 -1
  180. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +1 -1
  181. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  182. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  183. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +90 -146
  184. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
  185. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +4 -32
  186. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -119
  187. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -119
  188. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +13 -59
  189. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -31
  190. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -33
  191. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -21
  192. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -21
  193. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  194. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  195. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  196. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +5 -21
  197. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +9 -38
  198. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -34
  199. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +6 -35
  200. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +7 -6
  201. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +4 -124
  202. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +282 -80
  203. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +94 -46
  204. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +3 -3
  205. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  206. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +6 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  208. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +96 -148
  209. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -154
  210. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +98 -153
  211. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +25 -87
  212. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +89 -80
  213. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +5 -49
  214. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +80 -88
  215. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +8 -6
  216. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +15 -86
  217. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +20 -93
  218. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +5 -5
  219. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +3 -19
  220. diffusers/pipelines/unclip/pipeline_unclip.py +1 -1
  221. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -1
  222. diffusers/pipelines/unclip/text_proj.py +1 -1
  223. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +35 -35
  224. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  225. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +4 -21
  226. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +2 -2
  227. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -5
  228. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  229. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +1 -1
  230. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +2 -2
  231. diffusers/schedulers/__init__.py +7 -1
  232. diffusers/schedulers/deprecated/scheduling_karras_ve.py +1 -1
  233. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  234. diffusers/schedulers/scheduling_consistency_models.py +42 -19
  235. diffusers/schedulers/scheduling_ddim.py +2 -4
  236. diffusers/schedulers/scheduling_ddim_flax.py +13 -5
  237. diffusers/schedulers/scheduling_ddim_inverse.py +2 -4
  238. diffusers/schedulers/scheduling_ddim_parallel.py +2 -4
  239. diffusers/schedulers/scheduling_ddpm.py +2 -4
  240. diffusers/schedulers/scheduling_ddpm_flax.py +1 -1
  241. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -4
  242. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +1 -1
  243. diffusers/schedulers/scheduling_deis_multistep.py +46 -19
  244. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -21
  245. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +1 -1
  246. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +9 -7
  247. diffusers/schedulers/scheduling_dpmsolver_sde.py +35 -35
  248. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +49 -18
  249. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +683 -0
  250. diffusers/schedulers/scheduling_edm_euler.py +381 -0
  251. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +43 -15
  252. diffusers/schedulers/scheduling_euler_discrete.py +42 -17
  253. diffusers/schedulers/scheduling_euler_discrete_flax.py +1 -1
  254. diffusers/schedulers/scheduling_heun_discrete.py +35 -35
  255. diffusers/schedulers/scheduling_ipndm.py +37 -11
  256. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +44 -44
  257. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +44 -44
  258. diffusers/schedulers/scheduling_karras_ve_flax.py +1 -1
  259. diffusers/schedulers/scheduling_lcm.py +38 -14
  260. diffusers/schedulers/scheduling_lms_discrete.py +43 -15
  261. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  262. diffusers/schedulers/scheduling_pndm.py +2 -4
  263. diffusers/schedulers/scheduling_pndm_flax.py +2 -4
  264. diffusers/schedulers/scheduling_repaint.py +1 -1
  265. diffusers/schedulers/scheduling_sasolver.py +41 -9
  266. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  267. diffusers/schedulers/scheduling_sde_ve_flax.py +1 -1
  268. diffusers/schedulers/scheduling_tcd.py +686 -0
  269. diffusers/schedulers/scheduling_unclip.py +1 -1
  270. diffusers/schedulers/scheduling_unipc_multistep.py +46 -19
  271. diffusers/schedulers/scheduling_utils.py +2 -1
  272. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  273. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  274. diffusers/training_utils.py +9 -2
  275. diffusers/utils/__init__.py +2 -1
  276. diffusers/utils/accelerate_utils.py +1 -1
  277. diffusers/utils/constants.py +1 -1
  278. diffusers/utils/doc_utils.py +1 -1
  279. diffusers/utils/dummy_pt_objects.py +60 -0
  280. diffusers/utils/dummy_torch_and_transformers_objects.py +75 -0
  281. diffusers/utils/dynamic_modules_utils.py +1 -1
  282. diffusers/utils/export_utils.py +3 -3
  283. diffusers/utils/hub_utils.py +60 -16
  284. diffusers/utils/import_utils.py +15 -1
  285. diffusers/utils/loading_utils.py +2 -0
  286. diffusers/utils/logging.py +1 -1
  287. diffusers/utils/model_card_template.md +24 -0
  288. diffusers/utils/outputs.py +14 -7
  289. diffusers/utils/peft_utils.py +1 -1
  290. diffusers/utils/state_dict_utils.py +1 -1
  291. diffusers/utils/testing_utils.py +2 -0
  292. diffusers/utils/torch_utils.py +1 -1
  293. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/METADATA +46 -46
  294. diffusers-0.27.0.dist-info/RECORD +399 -0
  295. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/WHEEL +1 -1
  296. diffusers-0.26.3.dist-info/RECORD +0 -384
  297. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/LICENSE +0 -0
  298. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/entry_points.txt +0 -0
  299. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
- # Copyright 2023 The HuggingFace Team. All rights reserved.
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -11,6 +11,7 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+ import inspect
  from importlib import import_module
  from typing import Callable, Optional, Union

@@ -18,10 +19,11 @@ import torch
  import torch.nn.functional as F
  from torch import nn

- from ..utils import USE_PEFT_BACKEND, deprecate, logging
+ from ..image_processor import IPAdapterMaskProcessor
+ from ..utils import deprecate, logging
  from ..utils.import_utils import is_xformers_available
  from ..utils.torch_utils import maybe_allow_in_graph
- from .lora import LoRACompatibleLinear, LoRALinearLayer
+ from .lora import LoRALinearLayer


  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -114,6 +116,8 @@ class Attention(nn.Module):
  super().__init__()
  self.inner_dim = out_dim if out_dim is not None else dim_head * heads
  self.query_dim = query_dim
+ self.use_bias = bias
+ self.is_cross_attention = cross_attention_dim is not None
  self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
  self.upcast_attention = upcast_attention
  self.upcast_softmax = upcast_softmax
@@ -177,10 +181,7 @@ class Attention(nn.Module):
  f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
  )

- if USE_PEFT_BACKEND:
- linear_cls = nn.Linear
- else:
- linear_cls = LoRACompatibleLinear
+ linear_cls = nn.Linear

  self.linear_cls = linear_cls
  self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
@@ -509,6 +510,15 @@ class Attention(nn.Module):
  # The `Attention` class can call different attention processors / attention functions
  # here we simply pass along all tensors to the selected processor class
  # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
+
  return self.processor(
  self,
  hidden_states,
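For readers skimming the hunk above: the new guard inspects the active processor's `__call__` signature and drops, with a warning, any `cross_attention_kwargs` entries the processor cannot accept. A minimal self-contained sketch of that behavior (the toy processor below is illustrative, not library code):

```python
import inspect

def filter_cross_attention_kwargs(processor_call, cross_attention_kwargs):
    # Keep only the kwargs that the processor's __call__ actually declares.
    attn_parameters = set(inspect.signature(processor_call).parameters.keys())
    unused = [k for k in cross_attention_kwargs if k not in attn_parameters]
    if unused:
        print(f"cross_attention_kwargs {unused} are not expected and will be ignored.")
    return {k: v for k, v in cross_attention_kwargs.items() if k in attn_parameters}

def toy_processor_call(attn, hidden_states, temb=None):  # hypothetical processor
    return hidden_states

# "scale" is not in the toy signature, so it is filtered out with a warning.
print(filter_cross_attention_kwargs(toy_processor_call, {"temb": None, "scale": 0.5}))
```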
@@ -548,12 +558,16 @@ class Attention(nn.Module):
  `torch.Tensor`: The reshaped tensor.
  """
  head_size = self.heads
- batch_size, seq_len, dim = tensor.shape
- tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ if tensor.ndim == 3:
+ batch_size, seq_len, dim = tensor.shape
+ extra_dim = 1
+ else:
+ batch_size, extra_dim, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
  tensor = tensor.permute(0, 2, 1, 3)

  if out_dim == 3:
- tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+ tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)

  return tensor

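To see what the new `extra_dim` branch above buys, here is a small stand-alone reproduction (plain PyTorch, toy shapes, `head_size=8` assumed): 4-dim `[batch, extra, seq, dim]` inputs now fold into the same batched-head layout as 3-dim ones.

```python
import torch

def head_to_batch_dim(tensor, head_size=8, out_dim=3):
    # Mirrors the new branch: 3-dim tensors behave as before, 4-dim tensors
    # fold their extra dimension into the sequence length.
    if tensor.ndim == 3:
        batch_size, seq_len, dim = tensor.shape
        extra_dim = 1
    else:
        batch_size, extra_dim, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
    tensor = tensor.permute(0, 2, 1, 3)
    if out_dim == 3:
        tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
    return tensor

print(head_to_batch_dim(torch.randn(2, 77, 64)).shape)     # torch.Size([16, 77, 8])
print(head_to_batch_dim(torch.randn(2, 4, 77, 64)).shape)  # torch.Size([16, 308, 8])
```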
@@ -682,27 +696,32 @@ class Attention(nn.Module):

  @torch.no_grad()
  def fuse_projections(self, fuse=True):
- is_cross_attention = self.cross_attention_dim != self.query_dim
  device = self.to_q.weight.data.device
  dtype = self.to_q.weight.data.dtype

- if not is_cross_attention:
+ if not self.is_cross_attention:
  # fetch weight matrices.
  concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
  in_features = concatenated_weights.shape[1]
  out_features = concatenated_weights.shape[0]

  # create a new single projection layer and copy over the weights.
- self.to_qkv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
+ self.to_qkv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
  self.to_qkv.weight.copy_(concatenated_weights)
+ if self.use_bias:
+ concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+ self.to_qkv.bias.copy_(concatenated_bias)

  else:
  concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
  in_features = concatenated_weights.shape[1]
  out_features = concatenated_weights.shape[0]

- self.to_kv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
+ self.to_kv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
  self.to_kv.weight.copy_(concatenated_weights)
+ if self.use_bias:
+ concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+ self.to_kv.bias.copy_(concatenated_bias)

  self.fused_projections = fuse

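`fuse_projections` now carries the projection biases into the fused layer instead of silently dropping them. A quick stand-alone check in plain PyTorch (toy sizes, not library code) that concatenating weights and biases this way reproduces the three separate projections:

```python
import torch
from torch import nn

dim = 16
to_q, to_k, to_v = (nn.Linear(dim, dim, bias=True) for _ in range(3))

# Build one fused projection whose weight and bias are the q/k/v ones stacked.
to_qkv = nn.Linear(dim, 3 * dim, bias=True)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight]))
    to_qkv.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias]))

x = torch.randn(2, 7, dim)
q, k, v = to_qkv(x).chunk(3, dim=-1)
assert torch.allclose(q, to_q(x), atol=1e-6)
assert torch.allclose(k, to_k(x), atol=1e-6)
assert torch.allclose(v, to_v(x), atol=1e-6)
```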
@@ -719,11 +738,14 @@ class AttnProcessor:
  encoder_hidden_states: Optional[torch.FloatTensor] = None,
  attention_mask: Optional[torch.FloatTensor] = None,
  temb: Optional[torch.FloatTensor] = None,
- scale: float = 1.0,
+ *args,
+ **kwargs,
  ) -> torch.Tensor:
- residual = hidden_states
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)

- args = () if USE_PEFT_BACKEND else (scale,)
+ residual = hidden_states

  if attn.spatial_norm is not None:
  hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -742,15 +764,15 @@ class AttnProcessor:
  if attn.group_norm is not None:
  hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

- query = attn.to_q(hidden_states, *args)
+ query = attn.to_q(hidden_states)

  if encoder_hidden_states is None:
  encoder_hidden_states = hidden_states
  elif attn.norm_cross:
  encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

- key = attn.to_k(encoder_hidden_states, *args)
- value = attn.to_v(encoder_hidden_states, *args)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)

  query = attn.head_to_batch_dim(query)
  key = attn.head_to_batch_dim(key)
@@ -761,7 +783,7 @@ class AttnProcessor:
  hidden_states = attn.batch_to_head_dim(hidden_states)

  # linear proj
- hidden_states = attn.to_out[0](hidden_states, *args)
+ hidden_states = attn.to_out[0](hidden_states)
  # dropout
  hidden_states = attn.to_out[1](hidden_states)

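The same `scale` → `*args, **kwargs` deprecation pattern repeats in the remaining processors below; the user-facing consequence is that LoRA strength is no longer threaded through each processor call. A hedged usage sketch of the intended replacement (model id, LoRA path, and scale value are placeholders; assumes the PEFT backend, i.e. `peft` is installed):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/lora")  # placeholder path

# LoRA strength now travels through cross_attention_kwargs at call time,
# not through the processors' removed `scale` parameter.
image = pipe(
    "a pixel-art corgi", cross_attention_kwargs={"scale": 0.7}
).images[0]
```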
@@ -892,11 +914,14 @@ class AttnAddedKVProcessor:
  hidden_states: torch.FloatTensor,
  encoder_hidden_states: Optional[torch.FloatTensor] = None,
  attention_mask: Optional[torch.FloatTensor] = None,
- scale: float = 1.0,
+ *args,
+ **kwargs,
  ) -> torch.Tensor:
- residual = hidden_states
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)

- args = () if USE_PEFT_BACKEND else (scale,)
+ residual = hidden_states

  hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
  batch_size, sequence_length, _ = hidden_states.shape
@@ -910,17 +935,17 @@ class AttnAddedKVProcessor:

  hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

- query = attn.to_q(hidden_states, *args)
+ query = attn.to_q(hidden_states)
  query = attn.head_to_batch_dim(query)

- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
  encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
  encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

  if not attn.only_cross_attention:
- key = attn.to_k(hidden_states, *args)
- value = attn.to_v(hidden_states, *args)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
  key = attn.head_to_batch_dim(key)
  value = attn.head_to_batch_dim(value)
  key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
@@ -934,7 +959,7 @@ class AttnAddedKVProcessor:
  hidden_states = attn.batch_to_head_dim(hidden_states)

  # linear proj
- hidden_states = attn.to_out[0](hidden_states, *args)
+ hidden_states = attn.to_out[0](hidden_states)
  # dropout
  hidden_states = attn.to_out[1](hidden_states)

@@ -962,11 +987,14 @@ class AttnAddedKVProcessor2_0:
  hidden_states: torch.FloatTensor,
  encoder_hidden_states: Optional[torch.FloatTensor] = None,
  attention_mask: Optional[torch.FloatTensor] = None,
- scale: float = 1.0,
+ *args,
+ **kwargs,
  ) -> torch.Tensor:
- residual = hidden_states
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)

- args = () if USE_PEFT_BACKEND else (scale,)
+ residual = hidden_states

  hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
  batch_size, sequence_length, _ = hidden_states.shape
@@ -980,7 +1008,7 @@ class AttnAddedKVProcessor2_0:

  hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

- query = attn.to_q(hidden_states, *args)
+ query = attn.to_q(hidden_states)
  query = attn.head_to_batch_dim(query, out_dim=4)

  encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
@@ -989,8 +1017,8 @@ class AttnAddedKVProcessor2_0:
  encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)

  if not attn.only_cross_attention:
- key = attn.to_k(hidden_states, *args)
- value = attn.to_v(hidden_states, *args)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
  key = attn.head_to_batch_dim(key, out_dim=4)
  value = attn.head_to_batch_dim(value, out_dim=4)
  key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
@@ -1007,7 +1035,7 @@ class AttnAddedKVProcessor2_0:
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])

  # linear proj
- hidden_states = attn.to_out[0](hidden_states, *args)
+ hidden_states = attn.to_out[0](hidden_states)
  # dropout
  hidden_states = attn.to_out[1](hidden_states)

@@ -1110,11 +1138,14 @@ class XFormersAttnProcessor:
  encoder_hidden_states: Optional[torch.FloatTensor] = None,
  attention_mask: Optional[torch.FloatTensor] = None,
  temb: Optional[torch.FloatTensor] = None,
- scale: float = 1.0,
+ *args,
+ **kwargs,
  ) -> torch.FloatTensor:
- residual = hidden_states
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)

- args = () if USE_PEFT_BACKEND else (scale,)
+ residual = hidden_states

  if attn.spatial_norm is not None:
  hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1143,15 +1174,15 @@ class XFormersAttnProcessor:
  if attn.group_norm is not None:
  hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

- query = attn.to_q(hidden_states, *args)
+ query = attn.to_q(hidden_states)

  if encoder_hidden_states is None:
  encoder_hidden_states = hidden_states
  elif attn.norm_cross:
  encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

- key = attn.to_k(encoder_hidden_states, *args)
- value = attn.to_v(encoder_hidden_states, *args)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)

  query = attn.head_to_batch_dim(query).contiguous()
  key = attn.head_to_batch_dim(key).contiguous()
@@ -1164,7 +1195,7 @@ class XFormersAttnProcessor:
  hidden_states = attn.batch_to_head_dim(hidden_states)

  # linear proj
- hidden_states = attn.to_out[0](hidden_states, *args)
+ hidden_states = attn.to_out[0](hidden_states)
  # dropout
  hidden_states = attn.to_out[1](hidden_states)

@@ -1195,8 +1226,13 @@ class AttnProcessor2_0:
  encoder_hidden_states: Optional[torch.FloatTensor] = None,
  attention_mask: Optional[torch.FloatTensor] = None,
  temb: Optional[torch.FloatTensor] = None,
- scale: float = 1.0,
+ *args,
+ **kwargs,
  ) -> torch.FloatTensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
  residual = hidden_states
  if attn.spatial_norm is not None:
  hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1220,16 +1256,15 @@ class AttnProcessor2_0:
  if attn.group_norm is not None:
  hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

- args = () if USE_PEFT_BACKEND else (scale,)
- query = attn.to_q(hidden_states, *args)
+ query = attn.to_q(hidden_states)

  if encoder_hidden_states is None:
  encoder_hidden_states = hidden_states
  elif attn.norm_cross:
  encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

- key = attn.to_k(encoder_hidden_states, *args)
- value = attn.to_v(encoder_hidden_states, *args)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)

  inner_dim = key.shape[-1]
  head_dim = inner_dim // attn.heads
@@ -1249,7 +1284,7 @@ class AttnProcessor2_0:
  hidden_states = hidden_states.to(query.dtype)

  # linear proj
- hidden_states = attn.to_out[0](hidden_states, *args)
+ hidden_states = attn.to_out[0](hidden_states)
  # dropout
  hidden_states = attn.to_out[1](hidden_states)

@@ -1290,8 +1325,13 @@ class FusedAttnProcessor2_0:
  encoder_hidden_states: Optional[torch.FloatTensor] = None,
  attention_mask: Optional[torch.FloatTensor] = None,
  temb: Optional[torch.FloatTensor] = None,
- scale: float = 1.0,
+ *args,
+ **kwargs,
  ) -> torch.FloatTensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
  residual = hidden_states
  if attn.spatial_norm is not None:
  hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1315,17 +1355,16 @@ class FusedAttnProcessor2_0:
  if attn.group_norm is not None:
  hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

- args = () if USE_PEFT_BACKEND else (scale,)
  if encoder_hidden_states is None:
- qkv = attn.to_qkv(hidden_states, *args)
+ qkv = attn.to_qkv(hidden_states)
  split_size = qkv.shape[-1] // 3
  query, key, value = torch.split(qkv, split_size, dim=-1)
  else:
  if attn.norm_cross:
  encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
- query = attn.to_q(hidden_states, *args)
+ query = attn.to_q(hidden_states)

- kv = attn.to_kv(encoder_hidden_states, *args)
+ kv = attn.to_kv(encoder_hidden_states)
  split_size = kv.shape[-1] // 2
  key, value = torch.split(kv, split_size, dim=-1)

@@ -1346,7 +1385,7 @@ class FusedAttnProcessor2_0:
  hidden_states = hidden_states.to(query.dtype)

  # linear proj
- hidden_states = attn.to_out[0](hidden_states, *args)
+ hidden_states = attn.to_out[0](hidden_states)
  # dropout
  hidden_states = attn.to_out[1](hidden_states)

@@ -1799,24 +1838,7 @@ class SpatialNorm(nn.Module):
  return new_f


- ## Deprecated
  class LoRAAttnProcessor(nn.Module):
- r"""
- Processor for implementing the LoRA attention mechanism.
-
- Args:
- hidden_size (`int`, *optional*):
- The hidden size of the attention layer.
- cross_attention_dim (`int`, *optional*):
- The number of channels in the `encoder_hidden_states`.
- rank (`int`, defaults to 4):
- The dimension of the LoRA update matrices.
- network_alpha (`int`, *optional*):
- Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
- kwargs (`dict`):
- Additional keyword arguments to pass to the `LoRALinearLayer` layers.
- """
-
  def __init__(
  self,
  hidden_size: int,
@@ -1825,6 +1847,9 @@ class LoRAAttnProcessor(nn.Module):
  network_alpha: Optional[int] = None,
  **kwargs,
  ):
+ deprecation_message = "Using LoRAAttnProcessor is deprecated. Please use the PEFT backend for all things LoRA. You can install PEFT by running `pip install peft`."
+ deprecate("LoRAAttnProcessor", "0.30.0", deprecation_message, standard_warn=False)
+
  super().__init__()

  self.hidden_size = hidden_size
@@ -1851,7 +1876,7 @@ class LoRAAttnProcessor(nn.Module):
  self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
  self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)

- def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
  self_cls_name = self.__class__.__name__
  deprecate(
  self_cls_name,
@@ -1869,27 +1894,10 @@ class LoRAAttnProcessor(nn.Module):

  attn._modules.pop("processor")
  attn.processor = AttnProcessor()
- return attn.processor(attn, hidden_states, *args, **kwargs)
+ return attn.processor(attn, hidden_states, **kwargs)


  class LoRAAttnProcessor2_0(nn.Module):
- r"""
- Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
- attention.
-
- Args:
- hidden_size (`int`):
- The hidden size of the attention layer.
- cross_attention_dim (`int`, *optional*):
- The number of channels in the `encoder_hidden_states`.
- rank (`int`, defaults to 4):
- The dimension of the LoRA update matrices.
- network_alpha (`int`, *optional*):
- Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
- kwargs (`dict`):
- Additional keyword arguments to pass to the `LoRALinearLayer` layers.
- """
-
  def __init__(
  self,
  hidden_size: int,
@@ -1898,6 +1906,9 @@ class LoRAAttnProcessor2_0(nn.Module):
  network_alpha: Optional[int] = None,
  **kwargs,
  ):
+ deprecation_message = "Using LoRAAttnProcessor is deprecated. Please use the PEFT backend for all things LoRA. You can install PEFT by running `pip install peft`."
+ deprecate("LoRAAttnProcessor2_0", "0.30.0", deprecation_message, standard_warn=False)
+
  super().__init__()
  if not hasattr(F, "scaled_dot_product_attention"):
  raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@@ -1926,7 +1937,7 @@ class LoRAAttnProcessor2_0(nn.Module):
  self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
  self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)

- def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
  self_cls_name = self.__class__.__name__
  deprecate(
  self_cls_name,
@@ -1944,7 +1955,7 @@ class LoRAAttnProcessor2_0(nn.Module):

  attn._modules.pop("processor")
  attn.processor = AttnProcessor2_0()
- return attn.processor(attn, hidden_states, *args, **kwargs)
+ return attn.processor(attn, hidden_states, **kwargs)


  class LoRAXFormersAttnProcessor(nn.Module):
@@ -2005,7 +2016,7 @@ class LoRAXFormersAttnProcessor(nn.Module):
  self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
  self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)

- def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
  self_cls_name = self.__class__.__name__
  deprecate(
  self_cls_name,
@@ -2023,7 +2034,7 @@ class LoRAXFormersAttnProcessor(nn.Module):

  attn._modules.pop("processor")
  attn.processor = XFormersAttnProcessor()
- return attn.processor(attn, hidden_states, *args, **kwargs)
+ return attn.processor(attn, hidden_states, **kwargs)


  class LoRAAttnAddedKVProcessor(nn.Module):
@@ -2064,7 +2075,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
  self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
  self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

- def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
  self_cls_name = self.__class__.__name__
  deprecate(
  self_cls_name,
@@ -2082,7 +2093,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):

  attn._modules.pop("processor")
  attn.processor = AttnAddedKVProcessor()
- return attn.processor(attn, hidden_states, *args, **kwargs)
+ return attn.processor(attn, hidden_states, **kwargs)


  class IPAdapterAttnProcessor(nn.Module):
@@ -2125,12 +2136,13 @@ class IPAdapterAttnProcessor(nn.Module):

  def __call__(
  self,
- attn,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- scale=1.0,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ temb: Optional[torch.FloatTensor] = None,
+ scale: float = 1.0,
+ ip_adapter_masks: Optional[torch.FloatTensor] = None,
  ):
  residual = hidden_states

@@ -2185,9 +2197,22 @@ class IPAdapterAttnProcessor(nn.Module):
  hidden_states = torch.bmm(attention_probs, value)
  hidden_states = attn.batch_to_head_dim(hidden_states)

+ if ip_adapter_masks is not None:
+ if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
+ raise ValueError(
+ " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
+ " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+ )
+ if len(ip_adapter_masks) != len(self.scale):
+ raise ValueError(
+ f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
+ )
+ else:
+ ip_adapter_masks = [None] * len(self.scale)
+
  # for ip-adapter
- for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
- ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
+ for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+ ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
  ):
  ip_key = to_k_ip(current_ip_hidden_states)
  ip_value = to_v_ip(current_ip_hidden_states)
@@ -2199,6 +2224,15 @@ class IPAdapterAttnProcessor(nn.Module):
  current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
  current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)

+ if mask is not None:
+ mask_downsample = IPAdapterMaskProcessor.downsample(
+ mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
+ )
+
+ mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+
+ current_ip_hidden_states = current_ip_hidden_states * mask_downsample
+
  hidden_states = hidden_states + scale * current_ip_hidden_states

  # linear proj
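A hedged usage sketch of how masks are intended to reach these processors (model ids, weight file name, and image paths below are placeholders): masks are preprocessed with the newly imported `IPAdapterMaskProcessor` and forwarded through `cross_attention_kwargs`, one mask per loaded IP-Adapter.

```python
import torch
from diffusers import AutoPipelineForText2Image
from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.utils import load_image

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models",
                     weight_name="ip-adapter_sdxl.safetensors")

# One binary mask per loaded IP-Adapter, preprocessed to the output resolution.
mask = load_image("ip_mask.png")  # placeholder mask image
masks = IPAdapterMaskProcessor().preprocess([mask], height=1024, width=1024)

image = pipe(
    prompt="a portrait in a park",
    ip_adapter_image=load_image("reference.png"),  # placeholder reference image
    cross_attention_kwargs={"ip_adapter_masks": masks},
).images[0]
```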
@@ -2262,12 +2296,13 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):

  def __call__(
  self,
- attn,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- scale=1.0,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ temb: Optional[torch.FloatTensor] = None,
+ scale: float = 1.0,
+ ip_adapter_masks: Optional[torch.FloatTensor] = None,
  ):
  residual = hidden_states

@@ -2336,9 +2371,22 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
  hidden_states = hidden_states.to(query.dtype)

+ if ip_adapter_masks is not None:
+ if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
+ raise ValueError(
+ " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
+ " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+ )
+ if len(ip_adapter_masks) != len(self.scale):
+ raise ValueError(
+ f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
+ )
+ else:
+ ip_adapter_masks = [None] * len(self.scale)
+
  # for ip-adapter
- for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
- ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
+ for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+ ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
  ):
  ip_key = to_k_ip(current_ip_hidden_states)
  ip_value = to_v_ip(current_ip_hidden_states)
@@ -2357,6 +2405,15 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
  )
  current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)

+ if mask is not None:
+ mask_downsample = IPAdapterMaskProcessor.downsample(
+ mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
+ )
+
+ mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+
+ current_ip_hidden_states = current_ip_hidden_states * mask_downsample
+
  hidden_states = hidden_states + scale * current_ip_hidden_states

  # linear proj
@@ -1,4 +1,4 @@
- # Copyright 2023 The HuggingFace Team. All rights reserved.
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
- # Copyright 2023 The HuggingFace Team. All rights reserved.
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -80,6 +80,8 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
  norm_num_groups: int = 32,
  sample_size: int = 32,
  scaling_factor: float = 0.18215,
+ latents_mean: Optional[Tuple[float]] = None,
+ latents_std: Optional[Tuple[float]] = None,
  force_upcast: float = True,
  ):
  super().__init__()
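The new `latents_mean` / `latents_std` entries let a VAE checkpoint ship per-channel latent statistics next to `scaling_factor`. A hedged sketch of the kind of normalization downstream pipelines can apply when these fields are set (the helper and the identity statistics below are illustrative only, not library code):

```python
import torch

# Hypothetical helper: when per-channel statistics are configured, latents are
# standardized with them before applying the scalar scaling_factor; otherwise
# only the plain scaling_factor multiplication is used.
def normalize_latents(latents, scaling_factor=0.18215, latents_mean=None, latents_std=None):
    if latents_mean is not None and latents_std is not None:
        mean = torch.tensor(latents_mean).view(1, -1, 1, 1).to(latents)
        std = torch.tensor(latents_std).view(1, -1, 1, 1).to(latents)
        return (latents - mean) * scaling_factor / std
    return latents * scaling_factor

latents = torch.randn(1, 4, 64, 64)
print(normalize_latents(latents, latents_mean=(0.0,) * 4, latents_std=(1.0,) * 4).shape)
```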
@@ -1,4 +1,4 @@
- # Copyright 2023 The HuggingFace Team. All rights reserved.
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.