diffusers-0.26.3-py3-none-any.whl → diffusers-0.27.0-py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (299)
  1. diffusers/__init__.py +20 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +7 -3
  7. diffusers/dependency_versions_check.py +1 -1
  8. diffusers/dependency_versions_table.py +2 -2
  9. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  10. diffusers/image_processor.py +110 -4
  11. diffusers/loaders/autoencoder.py +7 -8
  12. diffusers/loaders/controlnet.py +17 -8
  13. diffusers/loaders/ip_adapter.py +86 -23
  14. diffusers/loaders/lora.py +105 -310
  15. diffusers/loaders/lora_conversion_utils.py +1 -1
  16. diffusers/loaders/peft.py +1 -1
  17. diffusers/loaders/single_file.py +51 -12
  18. diffusers/loaders/single_file_utils.py +274 -49
  19. diffusers/loaders/textual_inversion.py +23 -4
  20. diffusers/loaders/unet.py +195 -41
  21. diffusers/loaders/utils.py +1 -1
  22. diffusers/models/__init__.py +3 -1
  23. diffusers/models/activations.py +9 -9
  24. diffusers/models/attention.py +26 -36
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +171 -114
  27. diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
  28. diffusers/models/autoencoders/autoencoder_kl.py +3 -1
  29. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  30. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  31. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  32. diffusers/models/autoencoders/vae.py +1 -1
  33. diffusers/models/controlnet.py +1 -1
  34. diffusers/models/controlnet_flax.py +1 -1
  35. diffusers/models/downsampling.py +8 -12
  36. diffusers/models/dual_transformer_2d.py +1 -1
  37. diffusers/models/embeddings.py +3 -4
  38. diffusers/models/embeddings_flax.py +1 -1
  39. diffusers/models/lora.py +33 -10
  40. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  41. diffusers/models/modeling_flax_utils.py +1 -1
  42. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  43. diffusers/models/modeling_utils.py +4 -6
  44. diffusers/models/normalization.py +1 -1
  45. diffusers/models/resnet.py +31 -58
  46. diffusers/models/resnet_flax.py +1 -1
  47. diffusers/models/t5_film_transformer.py +1 -1
  48. diffusers/models/transformer_2d.py +1 -1
  49. diffusers/models/transformer_temporal.py +1 -1
  50. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  51. diffusers/models/transformers/t5_film_transformer.py +1 -1
  52. diffusers/models/transformers/transformer_2d.py +29 -31
  53. diffusers/models/transformers/transformer_temporal.py +1 -1
  54. diffusers/models/unet_1d.py +1 -1
  55. diffusers/models/unet_1d_blocks.py +1 -1
  56. diffusers/models/unet_2d.py +1 -1
  57. diffusers/models/unet_2d_blocks.py +1 -1
  58. diffusers/models/unet_2d_condition.py +1 -1
  59. diffusers/models/unets/__init__.py +1 -0
  60. diffusers/models/unets/unet_1d.py +1 -1
  61. diffusers/models/unets/unet_1d_blocks.py +1 -1
  62. diffusers/models/unets/unet_2d.py +4 -4
  63. diffusers/models/unets/unet_2d_blocks.py +238 -98
  64. diffusers/models/unets/unet_2d_blocks_flax.py +1 -1
  65. diffusers/models/unets/unet_2d_condition.py +420 -323
  66. diffusers/models/unets/unet_2d_condition_flax.py +21 -12
  67. diffusers/models/unets/unet_3d_blocks.py +50 -40
  68. diffusers/models/unets/unet_3d_condition.py +47 -8
  69. diffusers/models/unets/unet_i2vgen_xl.py +75 -30
  70. diffusers/models/unets/unet_kandinsky3.py +1 -1
  71. diffusers/models/unets/unet_motion_model.py +48 -8
  72. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  73. diffusers/models/unets/unet_stable_cascade.py +610 -0
  74. diffusers/models/unets/uvit_2d.py +1 -1
  75. diffusers/models/upsampling.py +10 -16
  76. diffusers/models/vae_flax.py +1 -1
  77. diffusers/models/vq_model.py +1 -1
  78. diffusers/optimization.py +1 -1
  79. diffusers/pipelines/__init__.py +26 -0
  80. diffusers/pipelines/amused/pipeline_amused.py +1 -1
  81. diffusers/pipelines/amused/pipeline_amused_img2img.py +1 -1
  82. diffusers/pipelines/amused/pipeline_amused_inpaint.py +1 -1
  83. diffusers/pipelines/animatediff/pipeline_animatediff.py +162 -417
  84. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +165 -137
  85. diffusers/pipelines/animatediff/pipeline_output.py +7 -6
  86. diffusers/pipelines/audioldm/pipeline_audioldm.py +3 -19
  87. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  88. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +3 -3
  89. diffusers/pipelines/auto_pipeline.py +7 -16
  90. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  93. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  94. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  95. diffusers/pipelines/controlnet/pipeline_controlnet.py +90 -90
  96. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  97. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +98 -90
  98. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +92 -90
  99. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +145 -70
  100. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +126 -89
  101. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +108 -96
  102. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  103. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -1
  104. diffusers/pipelines/ddim/pipeline_ddim.py +1 -1
  105. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  106. diffusers/pipelines/deepfloyd_if/pipeline_if.py +4 -4
  107. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +4 -4
  108. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +5 -5
  109. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +4 -4
  110. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +5 -5
  111. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +5 -5
  112. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +10 -120
  113. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -91
  114. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  115. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +1 -1
  116. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  117. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +1 -1
  118. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  119. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  120. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  121. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  122. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  123. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  124. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +5 -4
  125. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +5 -4
  126. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +7 -22
  127. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -39
  128. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +5 -5
  129. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  130. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -22
  131. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  132. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  133. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -2
  134. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  135. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  136. diffusers/pipelines/free_init_utils.py +184 -0
  137. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +22 -104
  138. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +1 -1
  139. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  140. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +1 -1
  141. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +2 -2
  142. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -1
  143. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -1
  145. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +1 -1
  146. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
  147. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  148. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +2 -2
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +104 -93
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +112 -74
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/ledits_pp/__init__.py +55 -0
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +1505 -0
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +1797 -0
  155. diffusers/pipelines/ledits_pp/pipeline_output.py +43 -0
  156. diffusers/pipelines/musicldm/pipeline_musicldm.py +3 -19
  157. diffusers/pipelines/onnx_utils.py +1 -1
  158. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  159. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +3 -3
  160. diffusers/pipelines/pia/pipeline_pia.py +168 -327
  161. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  162. diffusers/pipelines/pipeline_loading_utils.py +508 -0
  163. diffusers/pipelines/pipeline_utils.py +188 -534
  164. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +56 -10
  165. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +3 -3
  166. diffusers/pipelines/shap_e/camera.py +1 -1
  167. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  168. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  169. diffusers/pipelines/shap_e/renderer.py +1 -1
  170. diffusers/pipelines/stable_cascade/__init__.py +50 -0
  171. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +482 -0
  172. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +311 -0
  173. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +638 -0
  174. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  175. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +4 -1
  176. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  177. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +2 -2
  178. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  179. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +1 -1
  180. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +1 -1
  181. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  182. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  183. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +90 -146
  184. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
  185. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +4 -32
  186. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -119
  187. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -119
  188. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +13 -59
  189. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -31
  190. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -33
  191. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -21
  192. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -21
  193. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  194. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  195. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  196. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +5 -21
  197. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +9 -38
  198. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -34
  199. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +6 -35
  200. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +7 -6
  201. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +4 -124
  202. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +282 -80
  203. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +94 -46
  204. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +3 -3
  205. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  206. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +6 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  208. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +96 -148
  209. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -154
  210. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +98 -153
  211. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +25 -87
  212. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +89 -80
  213. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +5 -49
  214. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +80 -88
  215. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +8 -6
  216. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +15 -86
  217. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +20 -93
  218. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +5 -5
  219. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +3 -19
  220. diffusers/pipelines/unclip/pipeline_unclip.py +1 -1
  221. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -1
  222. diffusers/pipelines/unclip/text_proj.py +1 -1
  223. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +35 -35
  224. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  225. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +4 -21
  226. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +2 -2
  227. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -5
  228. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  229. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +1 -1
  230. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +2 -2
  231. diffusers/schedulers/__init__.py +7 -1
  232. diffusers/schedulers/deprecated/scheduling_karras_ve.py +1 -1
  233. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  234. diffusers/schedulers/scheduling_consistency_models.py +42 -19
  235. diffusers/schedulers/scheduling_ddim.py +2 -4
  236. diffusers/schedulers/scheduling_ddim_flax.py +13 -5
  237. diffusers/schedulers/scheduling_ddim_inverse.py +2 -4
  238. diffusers/schedulers/scheduling_ddim_parallel.py +2 -4
  239. diffusers/schedulers/scheduling_ddpm.py +2 -4
  240. diffusers/schedulers/scheduling_ddpm_flax.py +1 -1
  241. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -4
  242. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +1 -1
  243. diffusers/schedulers/scheduling_deis_multistep.py +46 -19
  244. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -21
  245. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +1 -1
  246. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +9 -7
  247. diffusers/schedulers/scheduling_dpmsolver_sde.py +35 -35
  248. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +49 -18
  249. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +683 -0
  250. diffusers/schedulers/scheduling_edm_euler.py +381 -0
  251. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +43 -15
  252. diffusers/schedulers/scheduling_euler_discrete.py +42 -17
  253. diffusers/schedulers/scheduling_euler_discrete_flax.py +1 -1
  254. diffusers/schedulers/scheduling_heun_discrete.py +35 -35
  255. diffusers/schedulers/scheduling_ipndm.py +37 -11
  256. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +44 -44
  257. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +44 -44
  258. diffusers/schedulers/scheduling_karras_ve_flax.py +1 -1
  259. diffusers/schedulers/scheduling_lcm.py +38 -14
  260. diffusers/schedulers/scheduling_lms_discrete.py +43 -15
  261. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  262. diffusers/schedulers/scheduling_pndm.py +2 -4
  263. diffusers/schedulers/scheduling_pndm_flax.py +2 -4
  264. diffusers/schedulers/scheduling_repaint.py +1 -1
  265. diffusers/schedulers/scheduling_sasolver.py +41 -9
  266. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  267. diffusers/schedulers/scheduling_sde_ve_flax.py +1 -1
  268. diffusers/schedulers/scheduling_tcd.py +686 -0
  269. diffusers/schedulers/scheduling_unclip.py +1 -1
  270. diffusers/schedulers/scheduling_unipc_multistep.py +46 -19
  271. diffusers/schedulers/scheduling_utils.py +2 -1
  272. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  273. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  274. diffusers/training_utils.py +9 -2
  275. diffusers/utils/__init__.py +2 -1
  276. diffusers/utils/accelerate_utils.py +1 -1
  277. diffusers/utils/constants.py +1 -1
  278. diffusers/utils/doc_utils.py +1 -1
  279. diffusers/utils/dummy_pt_objects.py +60 -0
  280. diffusers/utils/dummy_torch_and_transformers_objects.py +75 -0
  281. diffusers/utils/dynamic_modules_utils.py +1 -1
  282. diffusers/utils/export_utils.py +3 -3
  283. diffusers/utils/hub_utils.py +60 -16
  284. diffusers/utils/import_utils.py +15 -1
  285. diffusers/utils/loading_utils.py +2 -0
  286. diffusers/utils/logging.py +1 -1
  287. diffusers/utils/model_card_template.md +24 -0
  288. diffusers/utils/outputs.py +14 -7
  289. diffusers/utils/peft_utils.py +1 -1
  290. diffusers/utils/state_dict_utils.py +1 -1
  291. diffusers/utils/testing_utils.py +2 -0
  292. diffusers/utils/torch_utils.py +1 -1
  293. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/METADATA +46 -46
  294. diffusers-0.27.0.dist-info/RECORD +399 -0
  295. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/WHEEL +1 -1
  296. diffusers-0.26.3.dist-info/RECORD +0 -384
  297. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/LICENSE +0 -0
  298. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/entry_points.txt +0 -0
  299. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/top_level.txt +0 -0
diffusers/loaders/unet.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -37,10 +37,16 @@ from ..utils import (
37
37
  _get_model_file,
38
38
  delete_adapter_layers,
39
39
  is_accelerate_available,
40
+ is_torch_version,
40
41
  logging,
41
42
  set_adapter_layers,
42
43
  set_weights_and_activate_adapters,
43
44
  )
45
+ from .single_file_utils import (
46
+ convert_stable_cascade_unet_single_file_to_diffusers,
47
+ infer_stable_cascade_single_file_config,
48
+ load_single_file_model_checkpoint,
49
+ )
44
50
  from .utils import AttnProcsLayers
45
51
 
46
52
 
@@ -168,15 +174,6 @@ class UNet2DConditionLoadersMixin:
168
174
  "framework": "pytorch",
169
175
  }
170
176
 
171
- if low_cpu_mem_usage and not is_accelerate_available():
172
- low_cpu_mem_usage = False
173
- logger.warning(
174
- "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
175
- " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
176
- " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
177
- " install accelerate\n```\n."
178
- )
179
-
180
177
  model_file = None
181
178
  if not isinstance(pretrained_model_name_or_path_or_dict, dict):
182
179
  # Let's first try to load .safetensors weights
@@ -353,7 +350,7 @@ class UNet2DConditionLoadersMixin:
353
350
  is_model_cpu_offload = False
354
351
  is_sequential_cpu_offload = False
355
352
 
356
- # For PEFT backend the Unet is already offloaded at this stage as it is handled inside `lora_lora_weights_into_unet`
353
+ # For PEFT backend the Unet is already offloaded at this stage as it is handled inside `load_lora_weights_into_unet`
357
354
  if not USE_PEFT_BACKEND:
358
355
  if _pipeline is not None:
359
356
  for _, component in _pipeline.components.items():
@@ -392,7 +389,7 @@ class UNet2DConditionLoadersMixin:
392
389
  is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
393
390
  if is_text_encoder_present:
394
391
  warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
395
- logger.warn(warn_message)
392
+ logger.warning(warn_message)
396
393
  unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
397
394
  state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
398
395
 
@@ -694,9 +691,29 @@ class UNet2DConditionLoadersMixin:
694
691
  if hasattr(self, "peft_config"):
695
692
  self.peft_config.pop(adapter_name, None)
696
693
 
697
- def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
694
+ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
695
+ if low_cpu_mem_usage:
696
+ if is_accelerate_available():
697
+ from accelerate import init_empty_weights
698
+
699
+ else:
700
+ low_cpu_mem_usage = False
701
+ logger.warning(
702
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
703
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
704
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
705
+ " install accelerate\n```\n."
706
+ )
707
+
708
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
709
+ raise NotImplementedError(
710
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
711
+ " `low_cpu_mem_usage=False`."
712
+ )
713
+
698
714
  updated_state_dict = {}
699
715
  image_projection = None
716
+ init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
700
717
 
701
718
  if "proj.weight" in state_dict:
702
719
  # IP-Adapter
@@ -704,11 +721,12 @@ class UNet2DConditionLoadersMixin:
704
721
  clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
705
722
  cross_attention_dim = state_dict["proj.weight"].shape[0] // 4
706
723
 
707
- image_projection = ImageProjection(
708
- cross_attention_dim=cross_attention_dim,
709
- image_embed_dim=clip_embeddings_dim,
710
- num_image_text_embeds=num_image_text_embeds,
711
- )
724
+ with init_context():
725
+ image_projection = ImageProjection(
726
+ cross_attention_dim=cross_attention_dim,
727
+ image_embed_dim=clip_embeddings_dim,
728
+ num_image_text_embeds=num_image_text_embeds,
729
+ )
712
730
 
713
731
  for key, value in state_dict.items():
714
732
  diffusers_name = key.replace("proj", "image_embeds")
@@ -719,9 +737,10 @@ class UNet2DConditionLoadersMixin:
719
737
  clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
720
738
  cross_attention_dim = state_dict["proj.3.weight"].shape[0]
721
739
 
722
- image_projection = IPAdapterFullImageProjection(
723
- cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
724
- )
740
+ with init_context():
741
+ image_projection = IPAdapterFullImageProjection(
742
+ cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
743
+ )
725
744
 
726
745
  for key, value in state_dict.items():
727
746
  diffusers_name = key.replace("proj.0", "ff.net.0.proj")
@@ -737,13 +756,14 @@ class UNet2DConditionLoadersMixin:
737
756
  hidden_dims = state_dict["latents"].shape[2]
738
757
  heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
739
758
 
740
- image_projection = IPAdapterPlusImageProjection(
741
- embed_dims=embed_dims,
742
- output_dims=output_dims,
743
- hidden_dims=hidden_dims,
744
- heads=heads,
745
- num_queries=num_image_text_embeds,
746
- )
759
+ with init_context():
760
+ image_projection = IPAdapterPlusImageProjection(
761
+ embed_dims=embed_dims,
762
+ output_dims=output_dims,
763
+ hidden_dims=hidden_dims,
764
+ heads=heads,
765
+ num_queries=num_image_text_embeds,
766
+ )
747
767
 
748
768
  for key, value in state_dict.items():
749
769
  diffusers_name = key.replace("0.to", "2.to")
@@ -765,10 +785,14 @@ class UNet2DConditionLoadersMixin:
765
785
  else:
766
786
  updated_state_dict[diffusers_name] = value
767
787
 
768
- image_projection.load_state_dict(updated_state_dict)
788
+ if not low_cpu_mem_usage:
789
+ image_projection.load_state_dict(updated_state_dict)
790
+ else:
791
+ load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
792
+
769
793
  return image_projection
770
794
 
771
- def _convert_ip_adapter_attn_to_diffusers(self, state_dicts):
795
+ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
772
796
  from ..models.attention_processor import (
773
797
  AttnProcessor,
774
798
  AttnProcessor2_0,
@@ -776,9 +800,29 @@ class UNet2DConditionLoadersMixin:
776
800
  IPAdapterAttnProcessor2_0,
777
801
  )
778
802
 
803
+ if low_cpu_mem_usage:
804
+ if is_accelerate_available():
805
+ from accelerate import init_empty_weights
806
+
807
+ else:
808
+ low_cpu_mem_usage = False
809
+ logger.warning(
810
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
811
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
812
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
813
+ " install accelerate\n```\n."
814
+ )
815
+
816
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
817
+ raise NotImplementedError(
818
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
819
+ " `low_cpu_mem_usage=False`."
820
+ )
821
+
779
822
  # set ip-adapter cross-attention processors & load state_dict
780
823
  attn_procs = {}
781
824
  key_id = 1
825
+ init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
782
826
  for name in self.attn_processors.keys():
783
827
  cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
784
828
  if name.startswith("mid_block"):
@@ -811,39 +855,149 @@ class UNet2DConditionLoadersMixin:
811
855
  # IP-Adapter Plus
812
856
  num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
813
857
 
814
- attn_procs[name] = attn_processor_class(
815
- hidden_size=hidden_size,
816
- cross_attention_dim=cross_attention_dim,
817
- scale=1.0,
818
- num_tokens=num_image_text_embeds,
819
- ).to(dtype=self.dtype, device=self.device)
858
+ with init_context():
859
+ attn_procs[name] = attn_processor_class(
860
+ hidden_size=hidden_size,
861
+ cross_attention_dim=cross_attention_dim,
862
+ scale=1.0,
863
+ num_tokens=num_image_text_embeds,
864
+ )
820
865
 
821
866
  value_dict = {}
822
867
  for i, state_dict in enumerate(state_dicts):
823
868
  value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
824
869
  value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
825
870
 
826
- attn_procs[name].load_state_dict(value_dict)
871
+ if not low_cpu_mem_usage:
872
+ attn_procs[name].load_state_dict(value_dict)
873
+ else:
874
+ device = next(iter(value_dict.values())).device
875
+ dtype = next(iter(value_dict.values())).dtype
876
+ load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
877
+
827
878
  key_id += 2
828
879
 
829
880
  return attn_procs
830
881
 
831
- def _load_ip_adapter_weights(self, state_dicts):
882
+ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
832
883
  if not isinstance(state_dicts, list):
833
884
  state_dicts = [state_dicts]
834
885
  # Set encoder_hid_proj after loading ip_adapter weights,
835
886
  # because `IPAdapterPlusImageProjection` also has `attn_processors`.
836
887
  self.encoder_hid_proj = None
837
888
 
838
- attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts)
889
+ attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
839
890
  self.set_attn_processor(attn_procs)
840
891
 
841
892
  # convert IP-Adapter Image Projection layers to diffusers
842
893
  image_projection_layers = []
843
894
  for state_dict in state_dicts:
844
- image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
845
- image_projection_layer.to(device=self.device, dtype=self.dtype)
895
+ image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
896
+ state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
897
+ )
846
898
  image_projection_layers.append(image_projection_layer)
847
899
 
848
900
  self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
849
901
  self.config.encoder_hid_dim_type = "ip_image_proj"
902
+
903
+ self.to(dtype=self.dtype, device=self.device)
904
+
905
+
906
+ class FromOriginalUNetMixin:
907
+ """
908
+ Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`StableCascadeUNet`].
909
+ """
910
+
911
+ @classmethod
912
+ @validate_hf_hub_args
913
+ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
914
+ r"""
915
+ Instantiate a [`StableCascadeUNet`] from pretrained StableCascadeUNet weights saved in the original `.ckpt` or
916
+ `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
917
+
918
+ Parameters:
919
+ pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
920
+ Can be either:
921
+ - A link to the `.ckpt` file (for example
922
+ `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
923
+ - A path to a *file* containing all pipeline weights.
924
+ config: (`dict`, *optional*):
925
+ Dictionary containing the configuration of the model:
926
+ torch_dtype (`str` or `torch.dtype`, *optional*):
927
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
928
+ dtype is automatically derived from the model's weights.
929
+ force_download (`bool`, *optional*, defaults to `False`):
930
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
931
+ cached versions if they exist.
932
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
933
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
934
+ is not used.
935
+ resume_download (`bool`, *optional*, defaults to `False`):
936
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
937
+ incompletely downloaded files are deleted.
938
+ proxies (`Dict[str, str]`, *optional*):
939
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
940
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
941
+ local_files_only (`bool`, *optional*, defaults to `False`):
942
+ Whether to only load local model weights and configuration files or not. If set to True, the model
943
+ won't be downloaded from the Hub.
944
+ token (`str` or *bool*, *optional*):
945
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
946
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
947
+ revision (`str`, *optional*, defaults to `"main"`):
948
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
949
+ allowed by Git.
950
+ kwargs (remaining dictionary of keyword arguments, *optional*):
951
+ Can be used to overwrite load and saveable variables of the model.
952
+
953
+ """
954
+ class_name = cls.__name__
955
+ if class_name != "StableCascadeUNet":
956
+ raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet")
957
+
958
+ config = kwargs.pop("config", None)
959
+ resume_download = kwargs.pop("resume_download", False)
960
+ force_download = kwargs.pop("force_download", False)
961
+ proxies = kwargs.pop("proxies", None)
962
+ token = kwargs.pop("token", None)
963
+ cache_dir = kwargs.pop("cache_dir", None)
964
+ local_files_only = kwargs.pop("local_files_only", None)
965
+ revision = kwargs.pop("revision", None)
966
+ torch_dtype = kwargs.pop("torch_dtype", None)
967
+
968
+ checkpoint = load_single_file_model_checkpoint(
969
+ pretrained_model_link_or_path,
970
+ resume_download=resume_download,
971
+ force_download=force_download,
972
+ proxies=proxies,
973
+ token=token,
974
+ cache_dir=cache_dir,
975
+ local_files_only=local_files_only,
976
+ revision=revision,
977
+ )
978
+
979
+ if config is None:
980
+ config = infer_stable_cascade_single_file_config(checkpoint)
981
+ model_config = cls.load_config(**config, **kwargs)
982
+ else:
983
+ model_config = config
984
+
985
+ ctx = init_empty_weights if is_accelerate_available() else nullcontext
986
+ with ctx():
987
+ model = cls.from_config(model_config, **kwargs)
988
+
989
+ diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint)
990
+ if is_accelerate_available():
991
+ unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
992
+ if len(unexpected_keys) > 0:
993
+ logger.warn(
994
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
995
+ )
996
+
997
+ else:
998
+ model.load_state_dict(diffusers_format_checkpoint)
999
+
1000
+ if torch_dtype is not None:
1001
+ model.to(torch_dtype)
1002
+
1003
+ return model
@@ -1,4 +1,4 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -47,6 +47,7 @@ if is_torch_available():
47
47
  _import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"]
48
48
  _import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
49
49
  _import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
50
+ _import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"]
50
51
  _import_structure["unets.uvit_2d"] = ["UVit2DModel"]
51
52
  _import_structure["vq_model"] = ["VQModel"]
52
53
 
@@ -80,6 +81,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
80
81
  I2VGenXLUNet,
81
82
  Kandinsky3UNet,
82
83
  MotionAdapter,
84
+ StableCascadeUNet,
83
85
  UNet1DModel,
84
86
  UNet2DConditionModel,
85
87
  UNet2DModel,
@@ -1,5 +1,5 @@
1
1
  # coding=utf-8
2
- # Copyright 2023 HuggingFace Inc.
2
+ # Copyright 2024 HuggingFace Inc.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -17,8 +17,7 @@ import torch
17
17
  import torch.nn.functional as F
18
18
  from torch import nn
19
19
 
20
- from ..utils import USE_PEFT_BACKEND
21
- from .lora import LoRACompatibleLinear
20
+ from ..utils import deprecate
22
21
 
23
22
 
24
23
  ACTIVATION_FUNCTIONS = {
@@ -87,9 +86,7 @@ class GEGLU(nn.Module):
87
86
 
88
87
  def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
89
88
  super().__init__()
90
- linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
91
-
92
- self.proj = linear_cls(dim_in, dim_out * 2, bias=bias)
89
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
93
90
 
94
91
  def gelu(self, gate: torch.Tensor) -> torch.Tensor:
95
92
  if gate.device.type != "mps":
@@ -97,9 +94,12 @@ class GEGLU(nn.Module):
97
94
  # mps: gelu is not implemented for float16
98
95
  return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
99
96
 
100
- def forward(self, hidden_states, scale: float = 1.0):
101
- args = () if USE_PEFT_BACKEND else (scale,)
102
- hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
97
+ def forward(self, hidden_states, *args, **kwargs):
98
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
99
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
100
+ deprecate("scale", "1.0.0", deprecation_message)
101
+
102
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
103
103
  return hidden_states * self.gelu(gate)
104
104
 
105
105
 
@@ -1,4 +1,4 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -17,18 +17,18 @@ import torch
17
17
  import torch.nn.functional as F
18
18
  from torch import nn
19
19
 
20
- from ..utils import USE_PEFT_BACKEND
20
+ from ..utils import deprecate, logging
21
21
  from ..utils.torch_utils import maybe_allow_in_graph
22
22
  from .activations import GEGLU, GELU, ApproximateGELU
23
23
  from .attention_processor import Attention
24
24
  from .embeddings import SinusoidalPositionalEmbedding
25
- from .lora import LoRACompatibleLinear
26
25
  from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
27
26
 
28
27
 
29
- def _chunked_feed_forward(
30
- ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
31
- ):
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
32
32
  # "feed_forward_chunk_size" can be used to save memory
33
33
  if hidden_states.shape[chunk_dim] % chunk_size != 0:
34
34
  raise ValueError(
@@ -36,18 +36,10 @@ def _chunked_feed_forward(
36
36
  )
37
37
 
38
38
  num_chunks = hidden_states.shape[chunk_dim] // chunk_size
39
- if lora_scale is None:
40
- ff_output = torch.cat(
41
- [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
42
- dim=chunk_dim,
43
- )
44
- else:
45
- # TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete
46
- ff_output = torch.cat(
47
- [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
48
- dim=chunk_dim,
49
- )
50
-
39
+ ff_output = torch.cat(
40
+ [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
41
+ dim=chunk_dim,
42
+ )
51
43
  return ff_output
52
44
 
53
45
 
@@ -143,7 +135,7 @@ class BasicTransformerBlock(nn.Module):
143
135
  double_self_attention: bool = False,
144
136
  upcast_attention: bool = False,
145
137
  norm_elementwise_affine: bool = True,
146
- norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'layer_norm_i2vgen'
138
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
147
139
  norm_eps: float = 1e-5,
148
140
  final_dropout: bool = False,
149
141
  attention_type: str = "default",
@@ -158,6 +150,7 @@ class BasicTransformerBlock(nn.Module):
158
150
  super().__init__()
159
151
  self.only_cross_attention = only_cross_attention
160
152
 
153
+ # We keep these boolean flags for backward-compatibility.
161
154
  self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
162
155
  self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
163
156
  self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
@@ -298,6 +291,10 @@ class BasicTransformerBlock(nn.Module):
298
291
  class_labels: Optional[torch.LongTensor] = None,
299
292
  added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
300
293
  ) -> torch.FloatTensor:
294
+ if cross_attention_kwargs is not None:
295
+ if cross_attention_kwargs.get("scale", None) is not None:
296
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
297
+
301
298
  # Notice that normalization is always applied before the real computation in the following blocks.
302
299
  # 0. Self-Attention
303
300
  batch_size = hidden_states.shape[0]
@@ -325,10 +322,7 @@ class BasicTransformerBlock(nn.Module):
325
322
  if self.pos_embed is not None:
326
323
  norm_hidden_states = self.pos_embed(norm_hidden_states)
327
324
 
328
- # 1. Retrieve lora scale.
329
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
330
-
331
- # 2. Prepare GLIGEN inputs
325
+ # 1. Prepare GLIGEN inputs
332
326
  cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
333
327
  gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
334
328
 
@@ -347,7 +341,7 @@ class BasicTransformerBlock(nn.Module):
347
341
  if hidden_states.ndim == 4:
348
342
  hidden_states = hidden_states.squeeze(1)
349
343
 
350
- # 2.5 GLIGEN Control
344
+ # 1.2 GLIGEN Control
351
345
  if gligen_kwargs is not None:
352
346
  hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
353
347
 
@@ -393,11 +387,9 @@ class BasicTransformerBlock(nn.Module):
393
387
 
394
388
  if self._chunk_size is not None:
395
389
  # "feed_forward_chunk_size" can be used to save memory
396
- ff_output = _chunked_feed_forward(
397
- self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
398
- )
390
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
399
391
  else:
400
- ff_output = self.ff(norm_hidden_states, scale=lora_scale)
392
+ ff_output = self.ff(norm_hidden_states)
401
393
 
402
394
  if self.norm_type == "ada_norm_zero":
403
395
  ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -439,7 +431,6 @@ class TemporalBasicTransformerBlock(nn.Module):
439
431
 
440
432
  # Define 3 blocks. Each block has its own normalization layer.
441
433
  # 1. Self-Attn
442
- self.norm_in = nn.LayerNorm(dim)
443
434
  self.ff_in = FeedForward(
444
435
  dim,
445
436
  dim_out=time_mix_inner_dim,
@@ -643,7 +634,7 @@ class FeedForward(nn.Module):
643
634
  if inner_dim is None:
644
635
  inner_dim = int(dim * mult)
645
636
  dim_out = dim_out if dim_out is not None else dim
646
- linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
637
+ linear_cls = nn.Linear
647
638
 
648
639
  if activation_fn == "gelu":
649
640
  act_fn = GELU(dim, inner_dim, bias=bias)
@@ -665,11 +656,10 @@ class FeedForward(nn.Module):
665
656
  if final_dropout:
666
657
  self.net.append(nn.Dropout(dropout))
667
658
 
668
- def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
669
- compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
659
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
660
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
661
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
662
+ deprecate("scale", "1.0.0", deprecation_message)
670
663
  for module in self.net:
671
- if isinstance(module, compatible_cls):
672
- hidden_states = module(hidden_states, scale)
673
- else:
674
- hidden_states = module(hidden_states)
664
+ hidden_states = module(hidden_states)
675
665
  return hidden_states
@@ -1,4 +1,4 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.