diffusers 0.26.3__py3-none-any.whl → 0.27.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (299) hide show
  1. diffusers/__init__.py +20 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +7 -3
  7. diffusers/dependency_versions_check.py +1 -1
  8. diffusers/dependency_versions_table.py +2 -2
  9. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  10. diffusers/image_processor.py +110 -4
  11. diffusers/loaders/autoencoder.py +7 -8
  12. diffusers/loaders/controlnet.py +17 -8
  13. diffusers/loaders/ip_adapter.py +86 -23
  14. diffusers/loaders/lora.py +105 -310
  15. diffusers/loaders/lora_conversion_utils.py +1 -1
  16. diffusers/loaders/peft.py +1 -1
  17. diffusers/loaders/single_file.py +51 -12
  18. diffusers/loaders/single_file_utils.py +274 -49
  19. diffusers/loaders/textual_inversion.py +23 -4
  20. diffusers/loaders/unet.py +195 -41
  21. diffusers/loaders/utils.py +1 -1
  22. diffusers/models/__init__.py +3 -1
  23. diffusers/models/activations.py +9 -9
  24. diffusers/models/attention.py +26 -36
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +171 -114
  27. diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
  28. diffusers/models/autoencoders/autoencoder_kl.py +3 -1
  29. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  30. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  31. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  32. diffusers/models/autoencoders/vae.py +1 -1
  33. diffusers/models/controlnet.py +1 -1
  34. diffusers/models/controlnet_flax.py +1 -1
  35. diffusers/models/downsampling.py +8 -12
  36. diffusers/models/dual_transformer_2d.py +1 -1
  37. diffusers/models/embeddings.py +3 -4
  38. diffusers/models/embeddings_flax.py +1 -1
  39. diffusers/models/lora.py +33 -10
  40. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  41. diffusers/models/modeling_flax_utils.py +1 -1
  42. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  43. diffusers/models/modeling_utils.py +4 -6
  44. diffusers/models/normalization.py +1 -1
  45. diffusers/models/resnet.py +31 -58
  46. diffusers/models/resnet_flax.py +1 -1
  47. diffusers/models/t5_film_transformer.py +1 -1
  48. diffusers/models/transformer_2d.py +1 -1
  49. diffusers/models/transformer_temporal.py +1 -1
  50. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  51. diffusers/models/transformers/t5_film_transformer.py +1 -1
  52. diffusers/models/transformers/transformer_2d.py +29 -31
  53. diffusers/models/transformers/transformer_temporal.py +1 -1
  54. diffusers/models/unet_1d.py +1 -1
  55. diffusers/models/unet_1d_blocks.py +1 -1
  56. diffusers/models/unet_2d.py +1 -1
  57. diffusers/models/unet_2d_blocks.py +1 -1
  58. diffusers/models/unet_2d_condition.py +1 -1
  59. diffusers/models/unets/__init__.py +1 -0
  60. diffusers/models/unets/unet_1d.py +1 -1
  61. diffusers/models/unets/unet_1d_blocks.py +1 -1
  62. diffusers/models/unets/unet_2d.py +4 -4
  63. diffusers/models/unets/unet_2d_blocks.py +238 -98
  64. diffusers/models/unets/unet_2d_blocks_flax.py +1 -1
  65. diffusers/models/unets/unet_2d_condition.py +420 -323
  66. diffusers/models/unets/unet_2d_condition_flax.py +21 -12
  67. diffusers/models/unets/unet_3d_blocks.py +50 -40
  68. diffusers/models/unets/unet_3d_condition.py +47 -8
  69. diffusers/models/unets/unet_i2vgen_xl.py +75 -30
  70. diffusers/models/unets/unet_kandinsky3.py +1 -1
  71. diffusers/models/unets/unet_motion_model.py +48 -8
  72. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  73. diffusers/models/unets/unet_stable_cascade.py +610 -0
  74. diffusers/models/unets/uvit_2d.py +1 -1
  75. diffusers/models/upsampling.py +10 -16
  76. diffusers/models/vae_flax.py +1 -1
  77. diffusers/models/vq_model.py +1 -1
  78. diffusers/optimization.py +1 -1
  79. diffusers/pipelines/__init__.py +26 -0
  80. diffusers/pipelines/amused/pipeline_amused.py +1 -1
  81. diffusers/pipelines/amused/pipeline_amused_img2img.py +1 -1
  82. diffusers/pipelines/amused/pipeline_amused_inpaint.py +1 -1
  83. diffusers/pipelines/animatediff/pipeline_animatediff.py +162 -417
  84. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +165 -137
  85. diffusers/pipelines/animatediff/pipeline_output.py +7 -6
  86. diffusers/pipelines/audioldm/pipeline_audioldm.py +3 -19
  87. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  88. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +3 -3
  89. diffusers/pipelines/auto_pipeline.py +7 -16
  90. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  93. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  94. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  95. diffusers/pipelines/controlnet/pipeline_controlnet.py +90 -90
  96. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  97. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +98 -90
  98. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +92 -90
  99. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +145 -70
  100. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +126 -89
  101. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +108 -96
  102. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  103. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -1
  104. diffusers/pipelines/ddim/pipeline_ddim.py +1 -1
  105. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  106. diffusers/pipelines/deepfloyd_if/pipeline_if.py +4 -4
  107. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +4 -4
  108. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +5 -5
  109. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +4 -4
  110. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +5 -5
  111. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +5 -5
  112. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +10 -120
  113. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -91
  114. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  115. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +1 -1
  116. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  117. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +1 -1
  118. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  119. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  120. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  121. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  122. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  123. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  124. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +5 -4
  125. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +5 -4
  126. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +7 -22
  127. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -39
  128. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +5 -5
  129. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  130. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -22
  131. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  132. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  133. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -2
  134. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  135. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  136. diffusers/pipelines/free_init_utils.py +184 -0
  137. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +22 -104
  138. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +1 -1
  139. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  140. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +1 -1
  141. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +2 -2
  142. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -1
  143. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -1
  145. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +1 -1
  146. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
  147. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  148. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +2 -2
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +104 -93
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +112 -74
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/ledits_pp/__init__.py +55 -0
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +1505 -0
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +1797 -0
  155. diffusers/pipelines/ledits_pp/pipeline_output.py +43 -0
  156. diffusers/pipelines/musicldm/pipeline_musicldm.py +3 -19
  157. diffusers/pipelines/onnx_utils.py +1 -1
  158. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  159. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +3 -3
  160. diffusers/pipelines/pia/pipeline_pia.py +168 -327
  161. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  162. diffusers/pipelines/pipeline_loading_utils.py +508 -0
  163. diffusers/pipelines/pipeline_utils.py +188 -534
  164. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +56 -10
  165. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +3 -3
  166. diffusers/pipelines/shap_e/camera.py +1 -1
  167. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  168. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  169. diffusers/pipelines/shap_e/renderer.py +1 -1
  170. diffusers/pipelines/stable_cascade/__init__.py +50 -0
  171. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +482 -0
  172. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +311 -0
  173. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +638 -0
  174. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  175. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +4 -1
  176. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  177. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +2 -2
  178. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  179. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +1 -1
  180. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +1 -1
  181. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  182. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  183. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +90 -146
  184. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
  185. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +4 -32
  186. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -119
  187. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -119
  188. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +13 -59
  189. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -31
  190. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -33
  191. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -21
  192. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -21
  193. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  194. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  195. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  196. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +5 -21
  197. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +9 -38
  198. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -34
  199. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +6 -35
  200. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +7 -6
  201. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +4 -124
  202. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +282 -80
  203. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +94 -46
  204. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +3 -3
  205. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  206. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +6 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  208. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +96 -148
  209. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -154
  210. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +98 -153
  211. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +25 -87
  212. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +89 -80
  213. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +5 -49
  214. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +80 -88
  215. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +8 -6
  216. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +15 -86
  217. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +20 -93
  218. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +5 -5
  219. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +3 -19
  220. diffusers/pipelines/unclip/pipeline_unclip.py +1 -1
  221. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -1
  222. diffusers/pipelines/unclip/text_proj.py +1 -1
  223. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +35 -35
  224. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  225. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +4 -21
  226. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +2 -2
  227. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -5
  228. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  229. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +1 -1
  230. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +2 -2
  231. diffusers/schedulers/__init__.py +7 -1
  232. diffusers/schedulers/deprecated/scheduling_karras_ve.py +1 -1
  233. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  234. diffusers/schedulers/scheduling_consistency_models.py +42 -19
  235. diffusers/schedulers/scheduling_ddim.py +2 -4
  236. diffusers/schedulers/scheduling_ddim_flax.py +13 -5
  237. diffusers/schedulers/scheduling_ddim_inverse.py +2 -4
  238. diffusers/schedulers/scheduling_ddim_parallel.py +2 -4
  239. diffusers/schedulers/scheduling_ddpm.py +2 -4
  240. diffusers/schedulers/scheduling_ddpm_flax.py +1 -1
  241. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -4
  242. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +1 -1
  243. diffusers/schedulers/scheduling_deis_multistep.py +46 -19
  244. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -21
  245. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +1 -1
  246. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +9 -7
  247. diffusers/schedulers/scheduling_dpmsolver_sde.py +35 -35
  248. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +49 -18
  249. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +683 -0
  250. diffusers/schedulers/scheduling_edm_euler.py +381 -0
  251. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +43 -15
  252. diffusers/schedulers/scheduling_euler_discrete.py +42 -17
  253. diffusers/schedulers/scheduling_euler_discrete_flax.py +1 -1
  254. diffusers/schedulers/scheduling_heun_discrete.py +35 -35
  255. diffusers/schedulers/scheduling_ipndm.py +37 -11
  256. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +44 -44
  257. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +44 -44
  258. diffusers/schedulers/scheduling_karras_ve_flax.py +1 -1
  259. diffusers/schedulers/scheduling_lcm.py +38 -14
  260. diffusers/schedulers/scheduling_lms_discrete.py +43 -15
  261. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  262. diffusers/schedulers/scheduling_pndm.py +2 -4
  263. diffusers/schedulers/scheduling_pndm_flax.py +2 -4
  264. diffusers/schedulers/scheduling_repaint.py +1 -1
  265. diffusers/schedulers/scheduling_sasolver.py +41 -9
  266. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  267. diffusers/schedulers/scheduling_sde_ve_flax.py +1 -1
  268. diffusers/schedulers/scheduling_tcd.py +686 -0
  269. diffusers/schedulers/scheduling_unclip.py +1 -1
  270. diffusers/schedulers/scheduling_unipc_multistep.py +46 -19
  271. diffusers/schedulers/scheduling_utils.py +2 -1
  272. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  273. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  274. diffusers/training_utils.py +9 -2
  275. diffusers/utils/__init__.py +2 -1
  276. diffusers/utils/accelerate_utils.py +1 -1
  277. diffusers/utils/constants.py +1 -1
  278. diffusers/utils/doc_utils.py +1 -1
  279. diffusers/utils/dummy_pt_objects.py +60 -0
  280. diffusers/utils/dummy_torch_and_transformers_objects.py +75 -0
  281. diffusers/utils/dynamic_modules_utils.py +1 -1
  282. diffusers/utils/export_utils.py +3 -3
  283. diffusers/utils/hub_utils.py +60 -16
  284. diffusers/utils/import_utils.py +15 -1
  285. diffusers/utils/loading_utils.py +2 -0
  286. diffusers/utils/logging.py +1 -1
  287. diffusers/utils/model_card_template.md +24 -0
  288. diffusers/utils/outputs.py +14 -7
  289. diffusers/utils/peft_utils.py +1 -1
  290. diffusers/utils/state_dict_utils.py +1 -1
  291. diffusers/utils/testing_utils.py +2 -0
  292. diffusers/utils/torch_utils.py +1 -1
  293. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/METADATA +46 -46
  294. diffusers-0.27.0.dist-info/RECORD +399 -0
  295. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/WHEEL +1 -1
  296. diffusers-0.26.3.dist-info/RECORD +0 -384
  297. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/LICENSE +0 -0
  298. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/entry_points.txt +0 -0
  299. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@ import torch
18
18
  import torch.nn.functional as F
19
19
  from torch import nn
20
20
 
21
- from ...utils import is_torch_version, logging
21
+ from ...utils import deprecate, is_torch_version, logging
22
22
  from ...utils.torch_utils import apply_freeu
23
23
  from ..activations import get_activation
24
24
  from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
@@ -69,7 +69,7 @@ def get_down_block(
69
69
  ):
70
70
  # If attn head dim is not defined, we default it to the number of heads
71
71
  if attention_head_dim is None:
72
- logger.warn(
72
+ logger.warning(
73
73
  f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
74
74
  )
75
75
  attention_head_dim = num_attention_heads
@@ -249,6 +249,81 @@ def get_down_block(
249
249
  raise ValueError(f"{down_block_type} does not exist.")
250
250
 
251
251
 
252
+ def get_mid_block(
253
+ mid_block_type: str,
254
+ temb_channels: int,
255
+ in_channels: int,
256
+ resnet_eps: float,
257
+ resnet_act_fn: str,
258
+ resnet_groups: int,
259
+ output_scale_factor: float = 1.0,
260
+ transformer_layers_per_block: int = 1,
261
+ num_attention_heads: Optional[int] = None,
262
+ cross_attention_dim: Optional[int] = None,
263
+ dual_cross_attention: bool = False,
264
+ use_linear_projection: bool = False,
265
+ mid_block_only_cross_attention: bool = False,
266
+ upcast_attention: bool = False,
267
+ resnet_time_scale_shift: str = "default",
268
+ attention_type: str = "default",
269
+ resnet_skip_time_act: bool = False,
270
+ cross_attention_norm: Optional[str] = None,
271
+ attention_head_dim: Optional[int] = 1,
272
+ dropout: float = 0.0,
273
+ ):
274
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
275
+ return UNetMidBlock2DCrossAttn(
276
+ transformer_layers_per_block=transformer_layers_per_block,
277
+ in_channels=in_channels,
278
+ temb_channels=temb_channels,
279
+ dropout=dropout,
280
+ resnet_eps=resnet_eps,
281
+ resnet_act_fn=resnet_act_fn,
282
+ output_scale_factor=output_scale_factor,
283
+ resnet_time_scale_shift=resnet_time_scale_shift,
284
+ cross_attention_dim=cross_attention_dim,
285
+ num_attention_heads=num_attention_heads,
286
+ resnet_groups=resnet_groups,
287
+ dual_cross_attention=dual_cross_attention,
288
+ use_linear_projection=use_linear_projection,
289
+ upcast_attention=upcast_attention,
290
+ attention_type=attention_type,
291
+ )
292
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
293
+ return UNetMidBlock2DSimpleCrossAttn(
294
+ in_channels=in_channels,
295
+ temb_channels=temb_channels,
296
+ dropout=dropout,
297
+ resnet_eps=resnet_eps,
298
+ resnet_act_fn=resnet_act_fn,
299
+ output_scale_factor=output_scale_factor,
300
+ cross_attention_dim=cross_attention_dim,
301
+ attention_head_dim=attention_head_dim,
302
+ resnet_groups=resnet_groups,
303
+ resnet_time_scale_shift=resnet_time_scale_shift,
304
+ skip_time_act=resnet_skip_time_act,
305
+ only_cross_attention=mid_block_only_cross_attention,
306
+ cross_attention_norm=cross_attention_norm,
307
+ )
308
+ elif mid_block_type == "UNetMidBlock2D":
309
+ return UNetMidBlock2D(
310
+ in_channels=in_channels,
311
+ temb_channels=temb_channels,
312
+ dropout=dropout,
313
+ num_layers=0,
314
+ resnet_eps=resnet_eps,
315
+ resnet_act_fn=resnet_act_fn,
316
+ output_scale_factor=output_scale_factor,
317
+ resnet_groups=resnet_groups,
318
+ resnet_time_scale_shift=resnet_time_scale_shift,
319
+ add_attention=False,
320
+ )
321
+ elif mid_block_type is None:
322
+ return None
323
+ else:
324
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
325
+
326
+
252
327
  def get_up_block(
253
328
  up_block_type: str,
254
329
  num_layers: int,
@@ -279,7 +354,7 @@ def get_up_block(
279
354
  ) -> nn.Module:
280
355
  # If attn head dim is not defined, we default it to the number of heads
281
356
  if attention_head_dim is None:
282
- logger.warn(
357
+ logger.warning(
283
358
  f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
284
359
  )
285
360
  attention_head_dim = num_attention_heads
@@ -598,7 +673,7 @@ class UNetMidBlock2D(nn.Module):
598
673
  attentions = []
599
674
 
600
675
  if attention_head_dim is None:
601
- logger.warn(
676
+ logger.warning(
602
677
  f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
603
678
  )
604
679
  attention_head_dim = in_channels
@@ -769,8 +844,11 @@ class UNetMidBlock2DCrossAttn(nn.Module):
769
844
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
770
845
  encoder_attention_mask: Optional[torch.FloatTensor] = None,
771
846
  ) -> torch.FloatTensor:
772
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
773
- hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
847
+ if cross_attention_kwargs is not None:
848
+ if cross_attention_kwargs.get("scale", None) is not None:
849
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
850
+
851
+ hidden_states = self.resnets[0](hidden_states, temb)
774
852
  for attn, resnet in zip(self.attentions, self.resnets[1:]):
775
853
  if self.training and self.gradient_checkpointing:
776
854
 
@@ -807,7 +885,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
807
885
  encoder_attention_mask=encoder_attention_mask,
808
886
  return_dict=False,
809
887
  )[0]
810
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
888
+ hidden_states = resnet(hidden_states, temb)
811
889
 
812
890
  return hidden_states
813
891
 
@@ -907,7 +985,8 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
907
985
  encoder_attention_mask: Optional[torch.FloatTensor] = None,
908
986
  ) -> torch.FloatTensor:
909
987
  cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
910
- lora_scale = cross_attention_kwargs.get("scale", 1.0)
988
+ if cross_attention_kwargs.get("scale", None) is not None:
989
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
911
990
 
912
991
  if attention_mask is None:
913
992
  # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
@@ -920,7 +999,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
920
999
  # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
921
1000
  mask = attention_mask
922
1001
 
923
- hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
1002
+ hidden_states = self.resnets[0](hidden_states, temb)
924
1003
  for attn, resnet in zip(self.attentions, self.resnets[1:]):
925
1004
  # attn
926
1005
  hidden_states = attn(
@@ -931,7 +1010,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
931
1010
  )
932
1011
 
933
1012
  # resnet
934
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
1013
+ hidden_states = resnet(hidden_states, temb)
935
1014
 
936
1015
  return hidden_states
937
1016
 
@@ -960,7 +1039,7 @@ class AttnDownBlock2D(nn.Module):
960
1039
  self.downsample_type = downsample_type
961
1040
 
962
1041
  if attention_head_dim is None:
963
- logger.warn(
1042
+ logger.warning(
964
1043
  f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
965
1044
  )
966
1045
  attention_head_dim = out_channels
@@ -1036,23 +1115,22 @@ class AttnDownBlock2D(nn.Module):
1036
1115
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1037
1116
  ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
1038
1117
  cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
1039
-
1040
- lora_scale = cross_attention_kwargs.get("scale", 1.0)
1118
+ if cross_attention_kwargs.get("scale", None) is not None:
1119
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
1041
1120
 
1042
1121
  output_states = ()
1043
1122
 
1044
1123
  for resnet, attn in zip(self.resnets, self.attentions):
1045
- cross_attention_kwargs.update({"scale": lora_scale})
1046
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
1124
+ hidden_states = resnet(hidden_states, temb)
1047
1125
  hidden_states = attn(hidden_states, **cross_attention_kwargs)
1048
1126
  output_states = output_states + (hidden_states,)
1049
1127
 
1050
1128
  if self.downsamplers is not None:
1051
1129
  for downsampler in self.downsamplers:
1052
1130
  if self.downsample_type == "resnet":
1053
- hidden_states = downsampler(hidden_states, temb=temb, scale=lora_scale)
1131
+ hidden_states = downsampler(hidden_states, temb=temb)
1054
1132
  else:
1055
- hidden_states = downsampler(hidden_states, scale=lora_scale)
1133
+ hidden_states = downsampler(hidden_states)
1056
1134
 
1057
1135
  output_states += (hidden_states,)
1058
1136
 
@@ -1161,9 +1239,11 @@ class CrossAttnDownBlock2D(nn.Module):
1161
1239
  encoder_attention_mask: Optional[torch.FloatTensor] = None,
1162
1240
  additional_residuals: Optional[torch.FloatTensor] = None,
1163
1241
  ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
1164
- output_states = ()
1242
+ if cross_attention_kwargs is not None:
1243
+ if cross_attention_kwargs.get("scale", None) is not None:
1244
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
1165
1245
 
1166
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
1246
+ output_states = ()
1167
1247
 
1168
1248
  blocks = list(zip(self.resnets, self.attentions))
1169
1249
 
@@ -1195,7 +1275,7 @@ class CrossAttnDownBlock2D(nn.Module):
1195
1275
  return_dict=False,
1196
1276
  )[0]
1197
1277
  else:
1198
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
1278
+ hidden_states = resnet(hidden_states, temb)
1199
1279
  hidden_states = attn(
1200
1280
  hidden_states,
1201
1281
  encoder_hidden_states=encoder_hidden_states,
@@ -1213,7 +1293,7 @@ class CrossAttnDownBlock2D(nn.Module):
1213
1293
 
1214
1294
  if self.downsamplers is not None:
1215
1295
  for downsampler in self.downsamplers:
1216
- hidden_states = downsampler(hidden_states, scale=lora_scale)
1296
+ hidden_states = downsampler(hidden_states)
1217
1297
 
1218
1298
  output_states = output_states + (hidden_states,)
1219
1299
 
@@ -1273,8 +1353,12 @@ class DownBlock2D(nn.Module):
1273
1353
  self.gradient_checkpointing = False
1274
1354
 
1275
1355
  def forward(
1276
- self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
1356
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
1277
1357
  ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
1358
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
1359
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1360
+ deprecate("scale", "1.0.0", deprecation_message)
1361
+
1278
1362
  output_states = ()
1279
1363
 
1280
1364
  for resnet in self.resnets:
@@ -1295,13 +1379,13 @@ class DownBlock2D(nn.Module):
1295
1379
  create_custom_forward(resnet), hidden_states, temb
1296
1380
  )
1297
1381
  else:
1298
- hidden_states = resnet(hidden_states, temb, scale=scale)
1382
+ hidden_states = resnet(hidden_states, temb)
1299
1383
 
1300
1384
  output_states = output_states + (hidden_states,)
1301
1385
 
1302
1386
  if self.downsamplers is not None:
1303
1387
  for downsampler in self.downsamplers:
1304
- hidden_states = downsampler(hidden_states, scale=scale)
1388
+ hidden_states = downsampler(hidden_states)
1305
1389
 
1306
1390
  output_states = output_states + (hidden_states,)
1307
1391
 
@@ -1372,13 +1456,17 @@ class DownEncoderBlock2D(nn.Module):
1372
1456
  else:
1373
1457
  self.downsamplers = None
1374
1458
 
1375
- def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
1459
+ def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
1460
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
1461
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1462
+ deprecate("scale", "1.0.0", deprecation_message)
1463
+
1376
1464
  for resnet in self.resnets:
1377
- hidden_states = resnet(hidden_states, temb=None, scale=scale)
1465
+ hidden_states = resnet(hidden_states, temb=None)
1378
1466
 
1379
1467
  if self.downsamplers is not None:
1380
1468
  for downsampler in self.downsamplers:
1381
- hidden_states = downsampler(hidden_states, scale)
1469
+ hidden_states = downsampler(hidden_states)
1382
1470
 
1383
1471
  return hidden_states
1384
1472
 
@@ -1405,7 +1493,7 @@ class AttnDownEncoderBlock2D(nn.Module):
1405
1493
  attentions = []
1406
1494
 
1407
1495
  if attention_head_dim is None:
1408
- logger.warn(
1496
+ logger.warning(
1409
1497
  f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
1410
1498
  )
1411
1499
  attention_head_dim = out_channels
@@ -1470,15 +1558,18 @@ class AttnDownEncoderBlock2D(nn.Module):
1470
1558
  else:
1471
1559
  self.downsamplers = None
1472
1560
 
1473
- def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
1561
+ def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
1562
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
1563
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1564
+ deprecate("scale", "1.0.0", deprecation_message)
1565
+
1474
1566
  for resnet, attn in zip(self.resnets, self.attentions):
1475
- hidden_states = resnet(hidden_states, temb=None, scale=scale)
1476
- cross_attention_kwargs = {"scale": scale}
1477
- hidden_states = attn(hidden_states, **cross_attention_kwargs)
1567
+ hidden_states = resnet(hidden_states, temb=None)
1568
+ hidden_states = attn(hidden_states)
1478
1569
 
1479
1570
  if self.downsamplers is not None:
1480
1571
  for downsampler in self.downsamplers:
1481
- hidden_states = downsampler(hidden_states, scale)
1572
+ hidden_states = downsampler(hidden_states)
1482
1573
 
1483
1574
  return hidden_states
1484
1575
 
@@ -1504,7 +1595,7 @@ class AttnSkipDownBlock2D(nn.Module):
1504
1595
  self.resnets = nn.ModuleList([])
1505
1596
 
1506
1597
  if attention_head_dim is None:
1507
- logger.warn(
1598
+ logger.warning(
1508
1599
  f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
1509
1600
  )
1510
1601
  attention_head_dim = out_channels
@@ -1569,18 +1660,22 @@ class AttnSkipDownBlock2D(nn.Module):
1569
1660
  hidden_states: torch.FloatTensor,
1570
1661
  temb: Optional[torch.FloatTensor] = None,
1571
1662
  skip_sample: Optional[torch.FloatTensor] = None,
1572
- scale: float = 1.0,
1663
+ *args,
1664
+ **kwargs,
1573
1665
  ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
1666
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
1667
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1668
+ deprecate("scale", "1.0.0", deprecation_message)
1669
+
1574
1670
  output_states = ()
1575
1671
 
1576
1672
  for resnet, attn in zip(self.resnets, self.attentions):
1577
- hidden_states = resnet(hidden_states, temb, scale=scale)
1578
- cross_attention_kwargs = {"scale": scale}
1579
- hidden_states = attn(hidden_states, **cross_attention_kwargs)
1673
+ hidden_states = resnet(hidden_states, temb)
1674
+ hidden_states = attn(hidden_states)
1580
1675
  output_states += (hidden_states,)
1581
1676
 
1582
1677
  if self.downsamplers is not None:
1583
- hidden_states = self.resnet_down(hidden_states, temb, scale=scale)
1678
+ hidden_states = self.resnet_down(hidden_states, temb)
1584
1679
  for downsampler in self.downsamplers:
1585
1680
  skip_sample = downsampler(skip_sample)
1586
1681
 
@@ -1656,16 +1751,21 @@ class SkipDownBlock2D(nn.Module):
1656
1751
  hidden_states: torch.FloatTensor,
1657
1752
  temb: Optional[torch.FloatTensor] = None,
1658
1753
  skip_sample: Optional[torch.FloatTensor] = None,
1659
- scale: float = 1.0,
1754
+ *args,
1755
+ **kwargs,
1660
1756
  ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
1757
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
1758
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1759
+ deprecate("scale", "1.0.0", deprecation_message)
1760
+
1661
1761
  output_states = ()
1662
1762
 
1663
1763
  for resnet in self.resnets:
1664
- hidden_states = resnet(hidden_states, temb, scale)
1764
+ hidden_states = resnet(hidden_states, temb)
1665
1765
  output_states += (hidden_states,)
1666
1766
 
1667
1767
  if self.downsamplers is not None:
1668
- hidden_states = self.resnet_down(hidden_states, temb, scale)
1768
+ hidden_states = self.resnet_down(hidden_states, temb)
1669
1769
  for downsampler in self.downsamplers:
1670
1770
  skip_sample = downsampler(skip_sample)
1671
1771
 
@@ -1741,8 +1841,12 @@ class ResnetDownsampleBlock2D(nn.Module):
1741
1841
  self.gradient_checkpointing = False
1742
1842
 
1743
1843
  def forward(
1744
- self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
1844
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
1745
1845
  ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
1846
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
1847
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1848
+ deprecate("scale", "1.0.0", deprecation_message)
1849
+
1746
1850
  output_states = ()
1747
1851
 
1748
1852
  for resnet in self.resnets:
@@ -1763,13 +1867,13 @@ class ResnetDownsampleBlock2D(nn.Module):
1763
1867
  create_custom_forward(resnet), hidden_states, temb
1764
1868
  )
1765
1869
  else:
1766
- hidden_states = resnet(hidden_states, temb, scale)
1870
+ hidden_states = resnet(hidden_states, temb)
1767
1871
 
1768
1872
  output_states = output_states + (hidden_states,)
1769
1873
 
1770
1874
  if self.downsamplers is not None:
1771
1875
  for downsampler in self.downsamplers:
1772
- hidden_states = downsampler(hidden_states, temb, scale)
1876
+ hidden_states = downsampler(hidden_states, temb)
1773
1877
 
1774
1878
  output_states = output_states + (hidden_states,)
1775
1879
 
@@ -1880,10 +1984,11 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
1880
1984
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1881
1985
  encoder_attention_mask: Optional[torch.FloatTensor] = None,
1882
1986
  ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
1883
- output_states = ()
1884
1987
  cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
1988
+ if cross_attention_kwargs.get("scale", None) is not None:
1989
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
1885
1990
 
1886
- lora_scale = cross_attention_kwargs.get("scale", 1.0)
1991
+ output_states = ()
1887
1992
 
1888
1993
  if attention_mask is None:
1889
1994
  # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
@@ -1916,7 +2021,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
1916
2021
  **cross_attention_kwargs,
1917
2022
  )
1918
2023
  else:
1919
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
2024
+ hidden_states = resnet(hidden_states, temb)
1920
2025
 
1921
2026
  hidden_states = attn(
1922
2027
  hidden_states,
@@ -1929,7 +2034,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
1929
2034
 
1930
2035
  if self.downsamplers is not None:
1931
2036
  for downsampler in self.downsamplers:
1932
- hidden_states = downsampler(hidden_states, temb, scale=lora_scale)
2037
+ hidden_states = downsampler(hidden_states, temb)
1933
2038
 
1934
2039
  output_states = output_states + (hidden_states,)
1935
2040
 
@@ -1983,8 +2088,12 @@ class KDownBlock2D(nn.Module):
1983
2088
  self.gradient_checkpointing = False
1984
2089
 
1985
2090
  def forward(
1986
- self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
2091
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
1987
2092
  ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
2093
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
2094
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
2095
+ deprecate("scale", "1.0.0", deprecation_message)
2096
+
1988
2097
  output_states = ()
1989
2098
 
1990
2099
  for resnet in self.resnets:
@@ -2005,7 +2114,7 @@ class KDownBlock2D(nn.Module):
2005
2114
  create_custom_forward(resnet), hidden_states, temb
2006
2115
  )
2007
2116
  else:
2008
- hidden_states = resnet(hidden_states, temb, scale)
2117
+ hidden_states = resnet(hidden_states, temb)
2009
2118
 
2010
2119
  output_states += (hidden_states,)
2011
2120
 
@@ -2090,8 +2199,11 @@ class KCrossAttnDownBlock2D(nn.Module):
2090
2199
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
2091
2200
  encoder_attention_mask: Optional[torch.FloatTensor] = None,
2092
2201
  ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
2202
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
2203
+ if cross_attention_kwargs.get("scale", None) is not None:
2204
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
2205
+
2093
2206
  output_states = ()
2094
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
2095
2207
 
2096
2208
  for resnet, attn in zip(self.resnets, self.attentions):
2097
2209
  if self.training and self.gradient_checkpointing:
@@ -2121,7 +2233,7 @@ class KCrossAttnDownBlock2D(nn.Module):
2121
2233
  encoder_attention_mask=encoder_attention_mask,
2122
2234
  )
2123
2235
  else:
2124
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
2236
+ hidden_states = resnet(hidden_states, temb)
2125
2237
  hidden_states = attn(
2126
2238
  hidden_states,
2127
2239
  encoder_hidden_states=encoder_hidden_states,
@@ -2169,7 +2281,7 @@ class AttnUpBlock2D(nn.Module):
2169
2281
  self.upsample_type = upsample_type
2170
2282
 
2171
2283
  if attention_head_dim is None:
2172
- logger.warn(
2284
+ logger.warning(
2173
2285
  f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
2174
2286
  )
2175
2287
  attention_head_dim = out_channels
@@ -2241,24 +2353,28 @@ class AttnUpBlock2D(nn.Module):
2241
2353
  res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
2242
2354
  temb: Optional[torch.FloatTensor] = None,
2243
2355
  upsample_size: Optional[int] = None,
2244
- scale: float = 1.0,
2356
+ *args,
2357
+ **kwargs,
2245
2358
  ) -> torch.FloatTensor:
2359
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
2360
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
2361
+ deprecate("scale", "1.0.0", deprecation_message)
2362
+
2246
2363
  for resnet, attn in zip(self.resnets, self.attentions):
2247
2364
  # pop res hidden states
2248
2365
  res_hidden_states = res_hidden_states_tuple[-1]
2249
2366
  res_hidden_states_tuple = res_hidden_states_tuple[:-1]
2250
2367
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
2251
2368
 
2252
- hidden_states = resnet(hidden_states, temb, scale=scale)
2253
- cross_attention_kwargs = {"scale": scale}
2254
- hidden_states = attn(hidden_states, **cross_attention_kwargs)
2369
+ hidden_states = resnet(hidden_states, temb)
2370
+ hidden_states = attn(hidden_states)
2255
2371
 
2256
2372
  if self.upsamplers is not None:
2257
2373
  for upsampler in self.upsamplers:
2258
2374
  if self.upsample_type == "resnet":
2259
- hidden_states = upsampler(hidden_states, temb=temb, scale=scale)
2375
+ hidden_states = upsampler(hidden_states, temb=temb)
2260
2376
  else:
2261
- hidden_states = upsampler(hidden_states, scale=scale)
2377
+ hidden_states = upsampler(hidden_states)
2262
2378
 
2263
2379
  return hidden_states
2264
2380
 
@@ -2365,7 +2481,10 @@ class CrossAttnUpBlock2D(nn.Module):
2365
2481
  attention_mask: Optional[torch.FloatTensor] = None,
2366
2482
  encoder_attention_mask: Optional[torch.FloatTensor] = None,
2367
2483
  ) -> torch.FloatTensor:
2368
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
2484
+ if cross_attention_kwargs is not None:
2485
+ if cross_attention_kwargs.get("scale", None) is not None:
2486
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
2487
+
2369
2488
  is_freeu_enabled = (
2370
2489
  getattr(self, "s1", None)
2371
2490
  and getattr(self, "s2", None)
@@ -2419,7 +2538,7 @@ class CrossAttnUpBlock2D(nn.Module):
2419
2538
  return_dict=False,
2420
2539
  )[0]
2421
2540
  else:
2422
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
2541
+ hidden_states = resnet(hidden_states, temb)
2423
2542
  hidden_states = attn(
2424
2543
  hidden_states,
2425
2544
  encoder_hidden_states=encoder_hidden_states,
@@ -2431,7 +2550,7 @@ class CrossAttnUpBlock2D(nn.Module):
2431
2550
 
2432
2551
  if self.upsamplers is not None:
2433
2552
  for upsampler in self.upsamplers:
2434
- hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
2553
+ hidden_states = upsampler(hidden_states, upsample_size)
2435
2554
 
2436
2555
  return hidden_states
2437
2556
 
@@ -2492,8 +2611,13 @@ class UpBlock2D(nn.Module):
2492
2611
  res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
2493
2612
  temb: Optional[torch.FloatTensor] = None,
2494
2613
  upsample_size: Optional[int] = None,
2495
- scale: float = 1.0,
2614
+ *args,
2615
+ **kwargs,
2496
2616
  ) -> torch.FloatTensor:
2617
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
2618
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
2619
+ deprecate("scale", "1.0.0", deprecation_message)
2620
+
2497
2621
  is_freeu_enabled = (
2498
2622
  getattr(self, "s1", None)
2499
2623
  and getattr(self, "s2", None)
@@ -2537,11 +2661,11 @@ class UpBlock2D(nn.Module):
2537
2661
  create_custom_forward(resnet), hidden_states, temb
2538
2662
  )
2539
2663
  else:
2540
- hidden_states = resnet(hidden_states, temb, scale=scale)
2664
+ hidden_states = resnet(hidden_states, temb)
2541
2665
 
2542
2666
  if self.upsamplers is not None:
2543
2667
  for upsampler in self.upsamplers:
2544
- hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
2668
+ hidden_states = upsampler(hidden_states, upsample_size)
2545
2669
 
2546
2670
  return hidden_states
2547
2671
 
@@ -2608,11 +2732,9 @@ class UpDecoderBlock2D(nn.Module):
2608
2732
 
2609
2733
  self.resolution_idx = resolution_idx
2610
2734
 
2611
- def forward(
2612
- self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
2613
- ) -> torch.FloatTensor:
2735
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
2614
2736
  for resnet in self.resnets:
2615
- hidden_states = resnet(hidden_states, temb=temb, scale=scale)
2737
+ hidden_states = resnet(hidden_states, temb=temb)
2616
2738
 
2617
2739
  if self.upsamplers is not None:
2618
2740
  for upsampler in self.upsamplers:
@@ -2644,7 +2766,7 @@ class AttnUpDecoderBlock2D(nn.Module):
2644
2766
  attentions = []
2645
2767
 
2646
2768
  if attention_head_dim is None:
2647
- logger.warn(
2769
+ logger.warning(
2648
2770
  f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
2649
2771
  )
2650
2772
  attention_head_dim = out_channels
@@ -2708,17 +2830,14 @@ class AttnUpDecoderBlock2D(nn.Module):
2708
2830
 
2709
2831
  self.resolution_idx = resolution_idx
2710
2832
 
2711
- def forward(
2712
- self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
2713
- ) -> torch.FloatTensor:
2833
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
2714
2834
  for resnet, attn in zip(self.resnets, self.attentions):
2715
- hidden_states = resnet(hidden_states, temb=temb, scale=scale)
2716
- cross_attention_kwargs = {"scale": scale}
2717
- hidden_states = attn(hidden_states, temb=temb, **cross_attention_kwargs)
2835
+ hidden_states = resnet(hidden_states, temb=temb)
2836
+ hidden_states = attn(hidden_states, temb=temb)
2718
2837
 
2719
2838
  if self.upsamplers is not None:
2720
2839
  for upsampler in self.upsamplers:
2721
- hidden_states = upsampler(hidden_states, scale=scale)
2840
+ hidden_states = upsampler(hidden_states)
2722
2841
 
2723
2842
  return hidden_states
2724
2843
 
@@ -2766,7 +2885,7 @@ class AttnSkipUpBlock2D(nn.Module):
2766
2885
  )
2767
2886
 
2768
2887
  if attention_head_dim is None:
2769
- logger.warn(
2888
+ logger.warning(
2770
2889
  f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
2771
2890
  )
2772
2891
  attention_head_dim = out_channels
@@ -2823,18 +2942,22 @@ class AttnSkipUpBlock2D(nn.Module):
2823
2942
  res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
2824
2943
  temb: Optional[torch.FloatTensor] = None,
2825
2944
  skip_sample=None,
2826
- scale: float = 1.0,
2945
+ *args,
2946
+ **kwargs,
2827
2947
  ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
2948
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
2949
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
2950
+ deprecate("scale", "1.0.0", deprecation_message)
2951
+
2828
2952
  for resnet in self.resnets:
2829
2953
  # pop res hidden states
2830
2954
  res_hidden_states = res_hidden_states_tuple[-1]
2831
2955
  res_hidden_states_tuple = res_hidden_states_tuple[:-1]
2832
2956
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
2833
2957
 
2834
- hidden_states = resnet(hidden_states, temb, scale=scale)
2958
+ hidden_states = resnet(hidden_states, temb)
2835
2959
 
2836
- cross_attention_kwargs = {"scale": scale}
2837
- hidden_states = self.attentions[0](hidden_states, **cross_attention_kwargs)
2960
+ hidden_states = self.attentions[0](hidden_states)
2838
2961
 
2839
2962
  if skip_sample is not None:
2840
2963
  skip_sample = self.upsampler(skip_sample)
@@ -2848,7 +2971,7 @@ class AttnSkipUpBlock2D(nn.Module):
2848
2971
 
2849
2972
  skip_sample = skip_sample + skip_sample_states
2850
2973
 
2851
- hidden_states = self.resnet_up(hidden_states, temb, scale=scale)
2974
+ hidden_states = self.resnet_up(hidden_states, temb)
2852
2975
 
2853
2976
  return hidden_states, skip_sample
2854
2977
 
@@ -2931,15 +3054,20 @@ class SkipUpBlock2D(nn.Module):
2931
3054
  res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
2932
3055
  temb: Optional[torch.FloatTensor] = None,
2933
3056
  skip_sample=None,
2934
- scale: float = 1.0,
3057
+ *args,
3058
+ **kwargs,
2935
3059
  ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
3060
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
3061
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
3062
+ deprecate("scale", "1.0.0", deprecation_message)
3063
+
2936
3064
  for resnet in self.resnets:
2937
3065
  # pop res hidden states
2938
3066
  res_hidden_states = res_hidden_states_tuple[-1]
2939
3067
  res_hidden_states_tuple = res_hidden_states_tuple[:-1]
2940
3068
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
2941
3069
 
2942
- hidden_states = resnet(hidden_states, temb, scale=scale)
3070
+ hidden_states = resnet(hidden_states, temb)
2943
3071
 
2944
3072
  if skip_sample is not None:
2945
3073
  skip_sample = self.upsampler(skip_sample)
@@ -2953,7 +3081,7 @@ class SkipUpBlock2D(nn.Module):
2953
3081
 
2954
3082
  skip_sample = skip_sample + skip_sample_states
2955
3083
 
2956
- hidden_states = self.resnet_up(hidden_states, temb, scale=scale)
3084
+ hidden_states = self.resnet_up(hidden_states, temb)
2957
3085
 
2958
3086
  return hidden_states, skip_sample
2959
3087
 
@@ -3033,8 +3161,13 @@ class ResnetUpsampleBlock2D(nn.Module):
3033
3161
  res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
3034
3162
  temb: Optional[torch.FloatTensor] = None,
3035
3163
  upsample_size: Optional[int] = None,
3036
- scale: float = 1.0,
3164
+ *args,
3165
+ **kwargs,
3037
3166
  ) -> torch.FloatTensor:
3167
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
3168
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
3169
+ deprecate("scale", "1.0.0", deprecation_message)
3170
+
3038
3171
  for resnet in self.resnets:
3039
3172
  # pop res hidden states
3040
3173
  res_hidden_states = res_hidden_states_tuple[-1]
@@ -3058,11 +3191,11 @@ class ResnetUpsampleBlock2D(nn.Module):
3058
3191
  create_custom_forward(resnet), hidden_states, temb
3059
3192
  )
3060
3193
  else:
3061
- hidden_states = resnet(hidden_states, temb, scale=scale)
3194
+ hidden_states = resnet(hidden_states, temb)
3062
3195
 
3063
3196
  if self.upsamplers is not None:
3064
3197
  for upsampler in self.upsamplers:
3065
- hidden_states = upsampler(hidden_states, temb, scale=scale)
3198
+ hidden_states = upsampler(hidden_states, temb)
3066
3199
 
3067
3200
  return hidden_states
3068
3201
 
@@ -3178,8 +3311,9 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
3178
3311
  encoder_attention_mask: Optional[torch.FloatTensor] = None,
3179
3312
  ) -> torch.FloatTensor:
3180
3313
  cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
3314
+ if cross_attention_kwargs.get("scale", None) is not None:
3315
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
3181
3316
 
3182
- lora_scale = cross_attention_kwargs.get("scale", 1.0)
3183
3317
  if attention_mask is None:
3184
3318
  # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
3185
3319
  mask = None if encoder_hidden_states is None else encoder_attention_mask
@@ -3217,7 +3351,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
3217
3351
  **cross_attention_kwargs,
3218
3352
  )
3219
3353
  else:
3220
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
3354
+ hidden_states = resnet(hidden_states, temb)
3221
3355
 
3222
3356
  hidden_states = attn(
3223
3357
  hidden_states,
@@ -3228,7 +3362,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
3228
3362
 
3229
3363
  if self.upsamplers is not None:
3230
3364
  for upsampler in self.upsamplers:
3231
- hidden_states = upsampler(hidden_states, temb, scale=lora_scale)
3365
+ hidden_states = upsampler(hidden_states, temb)
3232
3366
 
3233
3367
  return hidden_states
3234
3368
 
@@ -3289,8 +3423,13 @@ class KUpBlock2D(nn.Module):
3289
3423
  res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
3290
3424
  temb: Optional[torch.FloatTensor] = None,
3291
3425
  upsample_size: Optional[int] = None,
3292
- scale: float = 1.0,
3426
+ *args,
3427
+ **kwargs,
3293
3428
  ) -> torch.FloatTensor:
3429
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
3430
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
3431
+ deprecate("scale", "1.0.0", deprecation_message)
3432
+
3294
3433
  res_hidden_states_tuple = res_hidden_states_tuple[-1]
3295
3434
  if res_hidden_states_tuple is not None:
3296
3435
  hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
@@ -3313,7 +3452,7 @@ class KUpBlock2D(nn.Module):
3313
3452
  create_custom_forward(resnet), hidden_states, temb
3314
3453
  )
3315
3454
  else:
3316
- hidden_states = resnet(hidden_states, temb, scale=scale)
3455
+ hidden_states = resnet(hidden_states, temb)
3317
3456
 
3318
3457
  if self.upsamplers is not None:
3319
3458
  for upsampler in self.upsamplers:
@@ -3423,7 +3562,6 @@ class KCrossAttnUpBlock2D(nn.Module):
3423
3562
  if res_hidden_states_tuple is not None:
3424
3563
  hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
3425
3564
 
3426
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
3427
3565
  for resnet, attn in zip(self.resnets, self.attentions):
3428
3566
  if self.training and self.gradient_checkpointing:
3429
3567
 
@@ -3452,7 +3590,7 @@ class KCrossAttnUpBlock2D(nn.Module):
3452
3590
  encoder_attention_mask=encoder_attention_mask,
3453
3591
  )
3454
3592
  else:
3455
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
3593
+ hidden_states = resnet(hidden_states, temb)
3456
3594
  hidden_states = attn(
3457
3595
  hidden_states,
3458
3596
  encoder_hidden_states=encoder_hidden_states,
@@ -3555,6 +3693,8 @@ class KAttentionBlock(nn.Module):
3555
3693
  encoder_attention_mask: Optional[torch.FloatTensor] = None,
3556
3694
  ) -> torch.FloatTensor:
3557
3695
  cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
3696
+ if cross_attention_kwargs.get("scale", None) is not None:
3697
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
3558
3698
 
3559
3699
  # 1. Self-Attention
3560
3700
  if self.add_self_attention: