diffusers-0.26.3-py3-none-any.whl → diffusers-0.27.0-py3-none-any.whl

Files changed (299)
  1. diffusers/__init__.py +20 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +7 -3
  7. diffusers/dependency_versions_check.py +1 -1
  8. diffusers/dependency_versions_table.py +2 -2
  9. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  10. diffusers/image_processor.py +110 -4
  11. diffusers/loaders/autoencoder.py +7 -8
  12. diffusers/loaders/controlnet.py +17 -8
  13. diffusers/loaders/ip_adapter.py +86 -23
  14. diffusers/loaders/lora.py +105 -310
  15. diffusers/loaders/lora_conversion_utils.py +1 -1
  16. diffusers/loaders/peft.py +1 -1
  17. diffusers/loaders/single_file.py +51 -12
  18. diffusers/loaders/single_file_utils.py +274 -49
  19. diffusers/loaders/textual_inversion.py +23 -4
  20. diffusers/loaders/unet.py +195 -41
  21. diffusers/loaders/utils.py +1 -1
  22. diffusers/models/__init__.py +3 -1
  23. diffusers/models/activations.py +9 -9
  24. diffusers/models/attention.py +26 -36
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +171 -114
  27. diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
  28. diffusers/models/autoencoders/autoencoder_kl.py +3 -1
  29. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  30. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  31. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  32. diffusers/models/autoencoders/vae.py +1 -1
  33. diffusers/models/controlnet.py +1 -1
  34. diffusers/models/controlnet_flax.py +1 -1
  35. diffusers/models/downsampling.py +8 -12
  36. diffusers/models/dual_transformer_2d.py +1 -1
  37. diffusers/models/embeddings.py +3 -4
  38. diffusers/models/embeddings_flax.py +1 -1
  39. diffusers/models/lora.py +33 -10
  40. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  41. diffusers/models/modeling_flax_utils.py +1 -1
  42. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  43. diffusers/models/modeling_utils.py +4 -6
  44. diffusers/models/normalization.py +1 -1
  45. diffusers/models/resnet.py +31 -58
  46. diffusers/models/resnet_flax.py +1 -1
  47. diffusers/models/t5_film_transformer.py +1 -1
  48. diffusers/models/transformer_2d.py +1 -1
  49. diffusers/models/transformer_temporal.py +1 -1
  50. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  51. diffusers/models/transformers/t5_film_transformer.py +1 -1
  52. diffusers/models/transformers/transformer_2d.py +29 -31
  53. diffusers/models/transformers/transformer_temporal.py +1 -1
  54. diffusers/models/unet_1d.py +1 -1
  55. diffusers/models/unet_1d_blocks.py +1 -1
  56. diffusers/models/unet_2d.py +1 -1
  57. diffusers/models/unet_2d_blocks.py +1 -1
  58. diffusers/models/unet_2d_condition.py +1 -1
  59. diffusers/models/unets/__init__.py +1 -0
  60. diffusers/models/unets/unet_1d.py +1 -1
  61. diffusers/models/unets/unet_1d_blocks.py +1 -1
  62. diffusers/models/unets/unet_2d.py +4 -4
  63. diffusers/models/unets/unet_2d_blocks.py +238 -98
  64. diffusers/models/unets/unet_2d_blocks_flax.py +1 -1
  65. diffusers/models/unets/unet_2d_condition.py +420 -323
  66. diffusers/models/unets/unet_2d_condition_flax.py +21 -12
  67. diffusers/models/unets/unet_3d_blocks.py +50 -40
  68. diffusers/models/unets/unet_3d_condition.py +47 -8
  69. diffusers/models/unets/unet_i2vgen_xl.py +75 -30
  70. diffusers/models/unets/unet_kandinsky3.py +1 -1
  71. diffusers/models/unets/unet_motion_model.py +48 -8
  72. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  73. diffusers/models/unets/unet_stable_cascade.py +610 -0
  74. diffusers/models/unets/uvit_2d.py +1 -1
  75. diffusers/models/upsampling.py +10 -16
  76. diffusers/models/vae_flax.py +1 -1
  77. diffusers/models/vq_model.py +1 -1
  78. diffusers/optimization.py +1 -1
  79. diffusers/pipelines/__init__.py +26 -0
  80. diffusers/pipelines/amused/pipeline_amused.py +1 -1
  81. diffusers/pipelines/amused/pipeline_amused_img2img.py +1 -1
  82. diffusers/pipelines/amused/pipeline_amused_inpaint.py +1 -1
  83. diffusers/pipelines/animatediff/pipeline_animatediff.py +162 -417
  84. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +165 -137
  85. diffusers/pipelines/animatediff/pipeline_output.py +7 -6
  86. diffusers/pipelines/audioldm/pipeline_audioldm.py +3 -19
  87. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  88. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +3 -3
  89. diffusers/pipelines/auto_pipeline.py +7 -16
  90. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  93. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  94. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  95. diffusers/pipelines/controlnet/pipeline_controlnet.py +90 -90
  96. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  97. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +98 -90
  98. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +92 -90
  99. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +145 -70
  100. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +126 -89
  101. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +108 -96
  102. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  103. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -1
  104. diffusers/pipelines/ddim/pipeline_ddim.py +1 -1
  105. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  106. diffusers/pipelines/deepfloyd_if/pipeline_if.py +4 -4
  107. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +4 -4
  108. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +5 -5
  109. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +4 -4
  110. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +5 -5
  111. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +5 -5
  112. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +10 -120
  113. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -91
  114. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  115. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +1 -1
  116. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  117. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +1 -1
  118. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  119. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  120. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  121. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  122. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  123. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  124. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +5 -4
  125. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +5 -4
  126. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +7 -22
  127. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -39
  128. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +5 -5
  129. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  130. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -22
  131. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  132. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  133. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -2
  134. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  135. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  136. diffusers/pipelines/free_init_utils.py +184 -0
  137. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +22 -104
  138. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +1 -1
  139. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  140. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +1 -1
  141. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +2 -2
  142. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -1
  143. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -1
  145. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +1 -1
  146. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
  147. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  148. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +2 -2
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +104 -93
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +112 -74
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/ledits_pp/__init__.py +55 -0
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +1505 -0
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +1797 -0
  155. diffusers/pipelines/ledits_pp/pipeline_output.py +43 -0
  156. diffusers/pipelines/musicldm/pipeline_musicldm.py +3 -19
  157. diffusers/pipelines/onnx_utils.py +1 -1
  158. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  159. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +3 -3
  160. diffusers/pipelines/pia/pipeline_pia.py +168 -327
  161. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  162. diffusers/pipelines/pipeline_loading_utils.py +508 -0
  163. diffusers/pipelines/pipeline_utils.py +188 -534
  164. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +56 -10
  165. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +3 -3
  166. diffusers/pipelines/shap_e/camera.py +1 -1
  167. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  168. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  169. diffusers/pipelines/shap_e/renderer.py +1 -1
  170. diffusers/pipelines/stable_cascade/__init__.py +50 -0
  171. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +482 -0
  172. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +311 -0
  173. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +638 -0
  174. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  175. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +4 -1
  176. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  177. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +2 -2
  178. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  179. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +1 -1
  180. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +1 -1
  181. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  182. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  183. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +90 -146
  184. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
  185. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +4 -32
  186. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -119
  187. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -119
  188. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +13 -59
  189. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -31
  190. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -33
  191. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -21
  192. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -21
  193. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  194. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  195. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  196. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +5 -21
  197. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +9 -38
  198. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -34
  199. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +6 -35
  200. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +7 -6
  201. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +4 -124
  202. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +282 -80
  203. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +94 -46
  204. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +3 -3
  205. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  206. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +6 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  208. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +96 -148
  209. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -154
  210. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +98 -153
  211. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +25 -87
  212. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +89 -80
  213. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +5 -49
  214. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +80 -88
  215. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +8 -6
  216. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +15 -86
  217. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +20 -93
  218. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +5 -5
  219. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +3 -19
  220. diffusers/pipelines/unclip/pipeline_unclip.py +1 -1
  221. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -1
  222. diffusers/pipelines/unclip/text_proj.py +1 -1
  223. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +35 -35
  224. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  225. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +4 -21
  226. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +2 -2
  227. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -5
  228. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  229. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +1 -1
  230. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +2 -2
  231. diffusers/schedulers/__init__.py +7 -1
  232. diffusers/schedulers/deprecated/scheduling_karras_ve.py +1 -1
  233. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  234. diffusers/schedulers/scheduling_consistency_models.py +42 -19
  235. diffusers/schedulers/scheduling_ddim.py +2 -4
  236. diffusers/schedulers/scheduling_ddim_flax.py +13 -5
  237. diffusers/schedulers/scheduling_ddim_inverse.py +2 -4
  238. diffusers/schedulers/scheduling_ddim_parallel.py +2 -4
  239. diffusers/schedulers/scheduling_ddpm.py +2 -4
  240. diffusers/schedulers/scheduling_ddpm_flax.py +1 -1
  241. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -4
  242. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +1 -1
  243. diffusers/schedulers/scheduling_deis_multistep.py +46 -19
  244. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -21
  245. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +1 -1
  246. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +9 -7
  247. diffusers/schedulers/scheduling_dpmsolver_sde.py +35 -35
  248. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +49 -18
  249. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +683 -0
  250. diffusers/schedulers/scheduling_edm_euler.py +381 -0
  251. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +43 -15
  252. diffusers/schedulers/scheduling_euler_discrete.py +42 -17
  253. diffusers/schedulers/scheduling_euler_discrete_flax.py +1 -1
  254. diffusers/schedulers/scheduling_heun_discrete.py +35 -35
  255. diffusers/schedulers/scheduling_ipndm.py +37 -11
  256. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +44 -44
  257. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +44 -44
  258. diffusers/schedulers/scheduling_karras_ve_flax.py +1 -1
  259. diffusers/schedulers/scheduling_lcm.py +38 -14
  260. diffusers/schedulers/scheduling_lms_discrete.py +43 -15
  261. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  262. diffusers/schedulers/scheduling_pndm.py +2 -4
  263. diffusers/schedulers/scheduling_pndm_flax.py +2 -4
  264. diffusers/schedulers/scheduling_repaint.py +1 -1
  265. diffusers/schedulers/scheduling_sasolver.py +41 -9
  266. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  267. diffusers/schedulers/scheduling_sde_ve_flax.py +1 -1
  268. diffusers/schedulers/scheduling_tcd.py +686 -0
  269. diffusers/schedulers/scheduling_unclip.py +1 -1
  270. diffusers/schedulers/scheduling_unipc_multistep.py +46 -19
  271. diffusers/schedulers/scheduling_utils.py +2 -1
  272. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  273. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  274. diffusers/training_utils.py +9 -2
  275. diffusers/utils/__init__.py +2 -1
  276. diffusers/utils/accelerate_utils.py +1 -1
  277. diffusers/utils/constants.py +1 -1
  278. diffusers/utils/doc_utils.py +1 -1
  279. diffusers/utils/dummy_pt_objects.py +60 -0
  280. diffusers/utils/dummy_torch_and_transformers_objects.py +75 -0
  281. diffusers/utils/dynamic_modules_utils.py +1 -1
  282. diffusers/utils/export_utils.py +3 -3
  283. diffusers/utils/hub_utils.py +60 -16
  284. diffusers/utils/import_utils.py +15 -1
  285. diffusers/utils/loading_utils.py +2 -0
  286. diffusers/utils/logging.py +1 -1
  287. diffusers/utils/model_card_template.md +24 -0
  288. diffusers/utils/outputs.py +14 -7
  289. diffusers/utils/peft_utils.py +1 -1
  290. diffusers/utils/state_dict_utils.py +1 -1
  291. diffusers/utils/testing_utils.py +2 -0
  292. diffusers/utils/torch_utils.py +1 -1
  293. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/METADATA +46 -46
  294. diffusers-0.27.0.dist-info/RECORD +399 -0
  295. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/WHEEL +1 -1
  296. diffusers-0.26.3.dist-info/RECORD +0 -384
  297. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/LICENSE +0 -0
  298. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/entry_points.txt +0 -0
  299. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/top_level.txt +0 -0
diffusers/__init__.py CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.26.3"
+__version__ = "0.27.0"
 
 from typing import TYPE_CHECKING
 
@@ -86,6 +86,7 @@ else:
             "MotionAdapter",
             "MultiAdapter",
             "PriorTransformer",
+            "StableCascadeUNet",
             "T2IAdapter",
             "T5FilmDecoder",
             "Transformer2DModel",
@@ -128,6 +129,7 @@ else:
             "PNDMPipeline",
             "RePaintPipeline",
             "ScoreSdeVePipeline",
+            "StableDiffusionMixin",
         ]
     )
     _import_structure["schedulers"].extend(
@@ -144,6 +146,8 @@ else:
             "DPMSolverMultistepInverseScheduler",
             "DPMSolverMultistepScheduler",
             "DPMSolverSinglestepScheduler",
+            "EDMDPMSolverMultistepScheduler",
+            "EDMEulerScheduler",
             "EulerAncestralDiscreteScheduler",
             "EulerDiscreteScheduler",
             "HeunDiscreteScheduler",
@@ -157,6 +161,7 @@ else:
             "SASolverScheduler",
             "SchedulerMixin",
             "ScoreSdeVeScheduler",
+            "TCDScheduler",
             "UnCLIPScheduler",
             "UniPCMultistepScheduler",
             "VQDiffusionScheduler",
@@ -248,6 +253,8 @@ else:
             "LatentConsistencyModelImg2ImgPipeline",
             "LatentConsistencyModelPipeline",
             "LDMTextToImagePipeline",
+            "LEditsPPPipelineStableDiffusion",
+            "LEditsPPPipelineStableDiffusionXL",
             "MusicLDMPipeline",
             "PaintByExamplePipeline",
             "PIAPipeline",
@@ -255,6 +262,9 @@ else:
             "SemanticStableDiffusionPipeline",
             "ShapEImg2ImgPipeline",
             "ShapEPipeline",
+            "StableCascadeCombinedPipeline",
+            "StableCascadeDecoderPipeline",
+            "StableCascadePriorPipeline",
             "StableDiffusionAdapterPipeline",
             "StableDiffusionAttendAndExcitePipeline",
             "StableDiffusionControlNetImg2ImgPipeline",
@@ -512,6 +522,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             PNDMPipeline,
             RePaintPipeline,
             ScoreSdeVePipeline,
+            StableDiffusionMixin,
         )
         from .schedulers import (
             AmusedScheduler,
@@ -526,6 +537,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            DPMSolverMultistepInverseScheduler,
            DPMSolverMultistepScheduler,
            DPMSolverSinglestepScheduler,
+            EDMDPMSolverMultistepScheduler,
+            EDMEulerScheduler,
            EulerAncestralDiscreteScheduler,
            EulerDiscreteScheduler,
            HeunDiscreteScheduler,
@@ -539,6 +552,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            SASolverScheduler,
            SchedulerMixin,
            ScoreSdeVeScheduler,
+            TCDScheduler,
            UnCLIPScheduler,
            UniPCMultistepScheduler,
            VQDiffusionScheduler,
@@ -611,6 +625,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            LatentConsistencyModelImg2ImgPipeline,
            LatentConsistencyModelPipeline,
            LDMTextToImagePipeline,
+            LEditsPPPipelineStableDiffusion,
+            LEditsPPPipelineStableDiffusionXL,
            MusicLDMPipeline,
            PaintByExamplePipeline,
            PIAPipeline,
@@ -618,6 +634,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            SemanticStableDiffusionPipeline,
            ShapEImg2ImgPipeline,
            ShapEPipeline,
+            StableCascadeCombinedPipeline,
+            StableCascadeDecoderPipeline,
+            StableCascadePriorPipeline,
            StableDiffusionAdapterPipeline,
            StableDiffusionAttendAndExcitePipeline,
            StableDiffusionControlNetImg2ImgPipeline,
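Taken together, the `__init__.py` changes expose the new 0.27.0 surface (Stable Cascade, LEdits++, the EDM and TCD schedulers, and the `StableDiffusionMixin` base class) as top-level imports. A minimal smoke test, assuming 0.27.0 is installed:

```py
# All names below are added by the diff above; importing them verifies
# that the new 0.27.0 public API is wired up.
from diffusers import (
    EDMDPMSolverMultistepScheduler,
    EDMEulerScheduler,
    LEditsPPPipelineStableDiffusion,
    StableCascadeDecoderPipeline,
    StableCascadePriorPipeline,
    StableCascadeUNet,
    StableDiffusionMixin,
    TCDScheduler,
)

print(TCDScheduler.__name__)  # the import itself is the test
```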
diffusers/commands/__init__.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diffusers/commands/diffusers_cli.py CHANGED
@@ -1,5 +1,5 @@
 #!/usr/bin/env python
-# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diffusers/commands/env.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diffusers/commands/fp16_safetensors.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diffusers/configuration_utils.py CHANGED
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2023 The HuggingFace Inc. team.
+# Copyright 2024 The HuggingFace Inc. team.
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -127,7 +127,7 @@ class ConfigMixin:
         """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
         config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
 
-        Tihs funtion is mostly copied from PyTorch's __getattr__ overwrite:
+        This function is mostly copied from PyTorch's __getattr__ overwrite:
         https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
         """
 
@@ -259,6 +259,10 @@ class ConfigMixin:
         model = cls(**init_dict)
 
         # make sure to also save config parameters that might be used for compatible classes
+        # update _class_name
+        if "_class_name" in hidden_dict:
+            hidden_dict["_class_name"] = cls.__name__
+
         model.register_to_config(**hidden_dict)
 
         # add hidden kwargs of compatible classes to unused_kwargs
@@ -529,7 +533,7 @@ class ConfigMixin:
                 f"{cls.config_name} configuration file."
             )
 
-        # 5. Give nice info if config attributes are initiliazed to default because they have not been passed
+        # 5. Give nice info if config attributes are initialized to default because they have not been passed
         passed_keys = set(init_dict.keys())
         if len(expected_keys - passed_keys) > 0:
             logger.info(
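The `_class_name` hunk matters when a config is instantiated by a *compatible* class rather than the one that saved it: the stored class name is now rewritten to the instantiating class, so a re-saved config round-trips to the right class. A sketch of the behavior, with the scheduler choice purely illustrative:

```py
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler

# Build a config under one class, then load it with a compatible class.
ddim = DDIMScheduler()
dpm = DPMSolverMultistepScheduler.from_config(ddim.config)

# With the hunk above, the hidden `_class_name` entry should now reflect
# the class that was actually instantiated, not the original DDIMScheduler.
print(dpm.config._class_name)  # expected: "DPMSolverMultistepScheduler"
```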
diffusers/dependency_versions_check.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diffusers/dependency_versions_table.py CHANGED
@@ -38,8 +38,8 @@ deps = {
     "regex": "regex!=2019.12.17",
     "requests": "requests",
     "tensorboard": "tensorboard",
-    "torch": "torch>=1.4,<2.2.0",
-    "torchvision": "torchvision<0.17",
+    "torch": "torch>=1.4",
+    "torchvision": "torchvision",
     "transformers": "transformers>=4.25.1",
     "urllib3": "urllib3<=2.0.0",
 }
diffusers/experimental/rl/value_guided_sampling.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diffusers/image_processor.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,12 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 import warnings
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import PIL.Image
 import torch
+import torch.nn.functional as F
 from PIL import Image, ImageFilter, ImageOps
 
 from .configuration_utils import ConfigMixin, register_to_config
@@ -330,7 +332,7 @@ class VaeImageProcessor(ConfigMixin):
         image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
         height: int,
         width: int,
-        resize_mode: str = "default",  # "defalt", "fill", "crop"
+        resize_mode: str = "default",  # "default", "fill", "crop"
     ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
         """
         Resize image.
@@ -446,7 +448,7 @@ class VaeImageProcessor(ConfigMixin):
         image: PipelineImageInput,
         height: Optional[int] = None,
         width: Optional[int] = None,
-        resize_mode: str = "default",  # "defalt", "fill", "crop"
+        resize_mode: str = "default",  # "default", "fill", "crop"
         crops_coords: Optional[Tuple[int, int, int, int]] = None,
     ) -> torch.Tensor:
         """
@@ -477,7 +479,7 @@ class VaeImageProcessor(ConfigMixin):
         if isinstance(image, torch.Tensor):
             # if image is a pytorch tensor could have 2 possible shapes:
             # 1. batch x height x width: we should insert the channel dimension at position 1
-            # 2. channnel x height x width: we should insert batch dimension at position 0,
+            # 2. channel x height x width: we should insert batch dimension at position 0,
             #    however, since both channel and batch dimension has same size 1, it is same to insert at position 1
             # for simplicity, we insert a dimension of size 1 at position 1 for both cases
             image = image.unsqueeze(1)
@@ -882,3 +884,107 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
             depth = self.binarize(depth)
 
         return rgb, depth
+
+
+class IPAdapterMaskProcessor(VaeImageProcessor):
+    """
+    Image processor for IP Adapter image masks.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
+        vae_scale_factor (`int`, *optional*, defaults to `8`):
+            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
+        resample (`str`, *optional*, defaults to `lanczos`):
+            Resampling filter to use when resizing the image.
+        do_normalize (`bool`, *optional*, defaults to `False`):
+            Whether to normalize the image to [-1,1].
+        do_binarize (`bool`, *optional*, defaults to `True`):
+            Whether to binarize the image to 0/1.
+        do_convert_grayscale (`bool`, *optional*, defaults to be `True`):
+            Whether to convert the images to grayscale format.
+
+    """
+
+    config_name = CONFIG_NAME
+
+    @register_to_config
+    def __init__(
+        self,
+        do_resize: bool = True,
+        vae_scale_factor: int = 8,
+        resample: str = "lanczos",
+        do_normalize: bool = False,
+        do_binarize: bool = True,
+        do_convert_grayscale: bool = True,
+    ):
+        super().__init__(
+            do_resize=do_resize,
+            vae_scale_factor=vae_scale_factor,
+            resample=resample,
+            do_normalize=do_normalize,
+            do_binarize=do_binarize,
+            do_convert_grayscale=do_convert_grayscale,
+        )
+
+    @staticmethod
+    def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int):
+        """
+        Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention.
+        If the aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
+
+        Args:
+            mask (`torch.FloatTensor`):
+                The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
+            batch_size (`int`):
+                The batch size.
+            num_queries (`int`):
+                The number of queries.
+            value_embed_dim (`int`):
+                The dimensionality of the value embeddings.
+
+        Returns:
+            `torch.FloatTensor`:
+                The downsampled mask tensor.
+
+        """
+        o_h = mask.shape[1]
+        o_w = mask.shape[2]
+        ratio = o_w / o_h
+        mask_h = int(math.sqrt(num_queries / ratio))
+        mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0)
+        mask_w = num_queries // mask_h
+
+        mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0)
+
+        # Repeat batch_size times
+        if mask_downsample.shape[0] < batch_size:
+            mask_downsample = mask_downsample.repeat(batch_size, 1, 1)
+
+        mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1)
+
+        downsampled_area = mask_h * mask_w
+        # If the output image and the mask do not have the same aspect ratio, tensor shapes will not match
+        # Pad tensor if downsampled_mask.shape[1] is smaller than num_queries
+        if downsampled_area < num_queries:
+            warnings.warn(
+                "The aspect ratio of the mask does not match the aspect ratio of the output image. "
+                "Please update your masks or adjust the output size for optimal performance.",
+                UserWarning,
+            )
+            mask_downsample = F.pad(mask_downsample, (0, num_queries - mask_downsample.shape[1]), value=0.0)
+        # Discard last embeddings if downsampled_mask.shape[1] is bigger than num_queries
+        if downsampled_area > num_queries:
+            warnings.warn(
+                "The aspect ratio of the mask does not match the aspect ratio of the output image. "
+                "Please update your masks or adjust the output size for optimal performance.",
+                UserWarning,
+            )
+            mask_downsample = mask_downsample[:, :num_queries]
+
+        # Repeat last dimension to match SDPA output shape
+        mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat(
+            1, 1, value_embed_dim
+        )
+
+        return mask_downsample
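The new `IPAdapterMaskProcessor` is meant to feed the reworked IP-Adapter attention processors (see `attention_processor.py` in the file list): `preprocess()` binarizes a grayscale mask and `downsample()` flattens it onto an attention layer's query grid. A minimal sketch, where the mask file and the query/embedding sizes are illustrative:

```py
from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.utils import load_image

mask = load_image("mask.png")  # hypothetical single-channel mask image

processor = IPAdapterMaskProcessor()
masks = processor.preprocess(mask)  # grayscale + binarized, shape (1, 1, H, W)

# Flatten onto an attention layer's query grid; 1024 queries corresponds to a
# roughly 32x32 attention map, and value_embed_dim must match that layer.
flat = IPAdapterMaskProcessor.downsample(
    masks[0], batch_size=2, num_queries=1024, value_embed_dim=1280
)
print(flat.shape)  # torch.Size([2, 1024, 1280])
```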
diffusers/loaders/autoencoder.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -75,10 +75,6 @@ class FromOriginalVAEMixin:
                 diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z
                 = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
                 Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
-            use_safetensors (`bool`, *optional*, defaults to `None`):
-                If set to `None`, the safetensors weights are downloaded if they're available **and** if the
-                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
-                weights. If set to `False`, safetensors weights are not loaded.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to overwrite load and saveable variables (for example the pipeline components of the
                 specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
@@ -111,7 +107,6 @@ class FromOriginalVAEMixin:
         local_files_only = kwargs.pop("local_files_only", None)
         revision = kwargs.pop("revision", None)
         torch_dtype = kwargs.pop("torch_dtype", None)
-        use_safetensors = kwargs.pop("use_safetensors", True)
 
         class_name = cls.__name__
 
@@ -131,14 +126,18 @@ class FromOriginalVAEMixin:
             token=token,
             revision=revision,
             local_files_only=local_files_only,
-            use_safetensors=use_safetensors,
             cache_dir=cache_dir,
         )
 
         image_size = kwargs.pop("image_size", None)
         scaling_factor = kwargs.pop("scaling_factor", None)
         component = create_diffusers_vae_model_from_ldm(
-            class_name, original_config, checkpoint, image_size=image_size, scaling_factor=scaling_factor
+            class_name,
+            original_config,
+            checkpoint,
+            image_size=image_size,
+            scaling_factor=scaling_factor,
+            torch_dtype=torch_dtype,
         )
         vae = component["vae"]
         if torch_dtype is not None:
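The net effect for `from_single_file` on VAEs: `use_safetensors` is no longer forwarded (format handling happens inside the fetch step) and `torch_dtype` now reaches the conversion itself. A usage sketch; the checkpoint URL is illustrative:

```py
import torch
from diffusers import AutoencoderKL

# torch_dtype is now passed through create_diffusers_vae_model_from_ldm, so
# the converted VAE comes back in the requested precision.
vae = AutoencoderKL.from_single_file(
    "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors",
    torch_dtype=torch.float16,
)
print(vae.dtype)  # torch.float16
```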
diffusers/loaders/controlnet.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -38,6 +38,9 @@ class FromOriginalControlNetMixin:
                     - A link to the `.ckpt` file (for example
                       `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
                     - A path to a *file* containing all pipeline weights.
+            config_file (`str`, *optional*):
+                Filepath to the configuration YAML file associated with the model. If not provided it will default to:
+                https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml
             torch_dtype (`str` or `torch.dtype`, *optional*):
                 Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
                 dtype is automatically derived from the model's weights.
@@ -62,10 +65,6 @@ class FromOriginalControlNetMixin:
             revision (`str`, *optional*, defaults to `"main"`):
                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                 allowed by Git.
-            use_safetensors (`bool`, *optional*, defaults to `None`):
-                If set to `None`, the safetensors weights are downloaded if they're available **and** if the
-                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
-                weights. If set to `False`, safetensors weights are not loaded.
             image_size (`int`, *optional*, defaults to 512):
                 The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
                 Diffusion v2 base model. Use 768 for Stable Diffusion v2.
@@ -89,6 +88,7 @@ class FromOriginalControlNetMixin:
         ```
         """
         original_config_file = kwargs.pop("original_config_file", None)
+        config_file = kwargs.pop("config_file", None)
         resume_download = kwargs.pop("resume_download", False)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
@@ -97,9 +97,14 @@ class FromOriginalControlNetMixin:
         local_files_only = kwargs.pop("local_files_only", None)
         revision = kwargs.pop("revision", None)
         torch_dtype = kwargs.pop("torch_dtype", None)
-        use_safetensors = kwargs.pop("use_safetensors", True)
 
         class_name = cls.__name__
+        if (config_file is not None) and (original_config_file is not None):
+            raise ValueError(
+                "You cannot pass both `config_file` and `original_config_file` to `from_single_file`. Please use only one of these arguments."
+            )
+
+        original_config_file = config_file or original_config_file
         original_config, checkpoint = fetch_ldm_config_and_checkpoint(
             pretrained_model_link_or_path=pretrained_model_link_or_path,
             class_name=class_name,
@@ -110,7 +115,6 @@ class FromOriginalControlNetMixin:
             token=token,
             revision=revision,
             local_files_only=local_files_only,
-            use_safetensors=use_safetensors,
             cache_dir=cache_dir,
         )
 
@@ -118,7 +122,12 @@ class FromOriginalControlNetMixin:
         image_size = kwargs.pop("image_size", None)
 
         component = create_diffusers_controlnet_model_from_ldm(
-            class_name, original_config, checkpoint, upcast_attention=upcast_attention, image_size=image_size
+            class_name,
+            original_config,
+            checkpoint,
+            upcast_attention=upcast_attention,
+            image_size=image_size,
+            torch_dtype=torch_dtype,
         )
         controlnet = component["controlnet"]
         if torch_dtype is not None:
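For ControlNet checkpoints this adds a `config_file` escape hatch for models whose architecture does not match the default `cldm_v15.yaml`, and makes it mutually exclusive with `original_config_file`. A sketch with illustrative local paths:

```py
from diffusers import ControlNetModel

# Passing both config_file and original_config_file now raises a ValueError;
# use exactly one. Paths below are placeholders.
controlnet = ControlNetModel.from_single_file(
    "./control_v11p_sd15_canny.pth",
    config_file="./cldm_v15.yaml",
)
```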
diffusers/loaders/ip_adapter.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,14 +13,17 @@
 # limitations under the License.
 
 from pathlib import Path
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union
 
 import torch
 from huggingface_hub.utils import validate_hf_hub_args
 from safetensors import safe_open
 
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
 from ..utils import (
     _get_model_file,
+    is_accelerate_available,
+    is_torch_version,
     is_transformers_available,
     logging,
 )
@@ -49,11 +52,12 @@ class IPAdapterMixin:
         pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
         subfolder: Union[str, List[str]],
         weight_name: Union[str, List[str]],
+        image_encoder_folder: Optional[str] = "image_encoder",
         **kwargs,
     ):
         """
         Parameters:
-            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+            pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
                 Can be either:
 
                     - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
@@ -62,7 +66,18 @@ class IPAdapterMixin:
                       with [`ModelMixin.save_pretrained`].
                     - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
+            subfolder (`str` or `List[str]`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+                If a list is passed, it should have the same length as `weight_name`.
+            weight_name (`str` or `List[str]`):
+                The name of the weight file to load. If a list is passed, it should have the same length as
+                `weight_name`.
+            image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
+                The subfolder location of the image encoder within a larger model repository on the Hub or locally.
+                Pass `None` to not load the image encoder. If the image encoder is located in a folder inside `subfolder`,
+                you only need to pass the name of the folder that contains image encoder weights, e.g. `image_encoder_folder="image_encoder"`.
+                If the image encoder is located in a folder other than `subfolder`, you should pass the path to the folder that contains image encoder weights,
+                for example, `image_encoder_folder="different_subfolder/image_encoder"`.
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                 is not used.
@@ -84,8 +99,11 @@ class IPAdapterMixin:
             revision (`str`, *optional*, defaults to `"main"`):
                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                 allowed by Git.
-            subfolder (`str`, *optional*, defaults to `""`):
-                The subfolder location of a model file within a larger model repository on the Hub or locally.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
         """
 
         # handle the list inputs for multiple IP Adapters
@@ -116,6 +134,22 @@ class IPAdapterMixin:
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+        if low_cpu_mem_usage and not is_accelerate_available():
+            low_cpu_mem_usage = False
+            logger.warning(
+                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                " install accelerate\n```\n."
+            )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
 
         user_agent = {
             "file_type": "attn_procs_weights",
@@ -160,32 +194,59 @@ class IPAdapterMixin:
 
         # load CLIP image encoder here if it has not been registered to the pipeline yet
         if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
-            if not isinstance(pretrained_model_name_or_path_or_dict, dict):
-                logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
-                image_encoder = CLIPVisionModelWithProjection.from_pretrained(
-                    pretrained_model_name_or_path_or_dict,
-                    subfolder=Path(subfolder, "image_encoder").as_posix(),
-                ).to(self.device, dtype=self.dtype)
-                self.image_encoder = image_encoder
-                self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"])
+            if image_encoder_folder is not None:
+                if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+                    logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
+                    if image_encoder_folder.count("/") == 0:
+                        image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
+                    else:
+                        image_encoder_subfolder = Path(image_encoder_folder).as_posix()
+
+                    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+                        pretrained_model_name_or_path_or_dict,
+                        subfolder=image_encoder_subfolder,
+                        low_cpu_mem_usage=low_cpu_mem_usage,
+                    ).to(self.device, dtype=self.dtype)
+                    self.register_modules(image_encoder=image_encoder)
+                else:
+                    raise ValueError(
+                        "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
+                    )
             else:
-                raise ValueError("`image_encoder` cannot be None when using IP Adapters.")
+                logger.warning(
+                    "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
+                    "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
+                )
 
         # create feature extractor if it has not been registered to the pipeline yet
         if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
             feature_extractor = CLIPImageProcessor()
             self.register_modules(feature_extractor=feature_extractor)
 
-        # load ip-adapter into unet
+        # load ip-adapter into unet
         unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-        unet._load_ip_adapter_weights(state_dicts)
+        unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
 
     def set_ip_adapter_scale(self, scale):
-        if not isinstance(scale, list):
-            scale = [scale]
+        """
+        Sets the conditioning scale between text and image.
+
+        Example:
+
+        ```py
+        pipeline.set_ip_adapter_scale(0.5)
+        ```
+        """
         unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
         for attn_processor in unet.attn_processors.values():
             if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
+                if not isinstance(scale, list):
+                    scale = [scale] * len(attn_processor.scale)
+                if len(attn_processor.scale) != len(scale):
+                    raise ValueError(
+                        f"`scale` should be a list of same length as the number if ip-adapters "
+                        f"Expected {len(attn_processor.scale)} but got {len(scale)}."
+                    )
                 attn_processor.scale = scale
 
     def unload_ip_adapter(self):
@@ -205,10 +266,12 @@ class IPAdapterMixin:
         self.image_encoder = None
         self.register_to_config(image_encoder=[None, None])
 
-        # remove feature extractor
-        if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
-            self.feature_extractor = None
-            self.register_to_config(feature_extractor=[None, None])
+        # remove feature extractor only when safety_checker is None as safety_checker uses
+        # the feature_extractor later
+        if not hasattr(self, "safety_checker"):
+            if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
+                self.feature_extractor = None
+                self.register_to_config(feature_extractor=[None, None])
 
         # remove hidden encoder
         self.unet.encoder_hid_proj = None
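In sum, `load_ip_adapter` gains `image_encoder_folder` (including `None` to skip the encoder entirely and rely on precomputed `ip_adapter_image_embeds`), honors `low_cpu_mem_usage`, and `set_ip_adapter_scale` now validates per-adapter scale lists. A usage sketch; the repository layout follows the common `h94/IP-Adapter` arrangement but is illustrative here:

```py
import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# image_encoder_folder with no "/" is resolved relative to `subfolder`;
# pass image_encoder_folder=None to skip loading the CLIP image encoder.
pipe.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="models",
    weight_name="ip-adapter_sd15.bin",
    image_encoder_folder="image_encoder",
)

# A scalar is broadcast to every IP-Adapter on each attention processor;
# a list must match the number of adapters loaded.
pipe.set_ip_adapter_scale(0.6)
```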