diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +19 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +20 -26
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +42 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
  229. diffusers/schedulers/scheduling_edm_euler.py +50 -31
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
  231. diffusers/schedulers/scheduling_euler_discrete.py +160 -68
  232. diffusers/schedulers/scheduling_heun_discrete.py +57 -39
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +24 -26
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
  267. diffusers-0.27.2.dist-info/RECORD +0 -399
  268. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  269. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/loaders/textual_inversion.py CHANGED
@@ -18,6 +18,7 @@ import torch
  from huggingface_hub.utils import validate_hf_hub_args
  from torch import nn

+ from ..models.modeling_utils import load_state_dict
  from ..utils import _get_model_file, is_accelerate_available, is_transformers_available, logging


@@ -37,7 +38,7 @@ TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
  def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
  cache_dir = kwargs.pop("cache_dir", None)
  force_download = kwargs.pop("force_download", False)
- resume_download = kwargs.pop("resume_download", False)
+ resume_download = kwargs.pop("resume_download", None)
  proxies = kwargs.pop("proxies", None)
  local_files_only = kwargs.pop("local_files_only", None)
  token = kwargs.pop("token", None)
@@ -100,7 +101,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
  subfolder=subfolder,
  user_agent=user_agent,
  )
- state_dict = torch.load(model_file, map_location="cpu")
+ state_dict = load_state_dict(model_file)
  else:
  state_dict = pretrained_model_name_or_path

@@ -307,9 +308,9 @@ class TextualInversionLoaderMixin:
  force_download (`bool`, *optional*, defaults to `False`):
  Whether or not to force the (re-)download of the model weights and configuration files, overriding the
  cached versions if they exist.
- resume_download (`bool`, *optional*, defaults to `False`):
- Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
- incompletely downloaded files are deleted.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
+ of Diffusers.
  proxies (`Dict[str, str]`, *optional*):
  A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
  'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -418,15 +419,20 @@ class TextualInversionLoaderMixin:
  # 7.1 Offload all hooks in case the pipeline was cpu offloaded before make sure, we offload and onload again
  is_model_cpu_offload = False
  is_sequential_cpu_offload = False
- for _, component in self.components.items():
- if isinstance(component, nn.Module):
- if hasattr(component, "_hf_hook"):
- is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
- is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
- logger.info(
- "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
- )
- remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+ if self.hf_device_map is None:
+ for _, component in self.components.items():
+ if isinstance(component, nn.Module):
+ if hasattr(component, "_hf_hook"):
+ is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+ is_sequential_cpu_offload = (
+ isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+ or hasattr(component._hf_hook, "hooks")
+ and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+ )
+ logger.info(
+ "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
+ )
+ remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

  # 7.2 save expected device and dtype
  device = text_encoder.device
@@ -486,20 +492,35 @@ class TextualInversionLoaderMixin:

  # Example 3: unload from SDXL
  pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
- embedding_path = hf_hub_download(repo_id="linoyts/web_y2k", filename="web_y2k_emb.safetensors", repo_type="model")
+ embedding_path = hf_hub_download(
+ repo_id="linoyts/web_y2k", filename="web_y2k_emb.safetensors", repo_type="model"
+ )

  # load embeddings to the text encoders
  state_dict = load_file(embedding_path)

  # load embeddings of text_encoder 1 (CLIP ViT-L/14)
- pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
+ pipeline.load_textual_inversion(
+ state_dict["clip_l"],
+ token=["<s0>", "<s1>"],
+ text_encoder=pipeline.text_encoder,
+ tokenizer=pipeline.tokenizer,
+ )
  # load embeddings of text_encoder 2 (CLIP ViT-G/14)
- pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
+ pipeline.load_textual_inversion(
+ state_dict["clip_g"],
+ token=["<s0>", "<s1>"],
+ text_encoder=pipeline.text_encoder_2,
+ tokenizer=pipeline.tokenizer_2,
+ )

  # Unload explicitly from both text encoders abd tokenizers
- pipeline.unload_textual_inversion(tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
- pipeline.unload_textual_inversion(tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
-
+ pipeline.unload_textual_inversion(
+ tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer
+ )
+ pipeline.unload_textual_inversion(
+ tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2
+ )
  ```
  """

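Note on the hunks above: `load_textual_inversion()` now treats `resume_download` as deprecated (it is ignored and downloads resume by default) and loads non-safetensors files through the library's `load_state_dict` helper instead of a bare `torch.load`. A minimal sketch of what this means for callers; the repo ids are illustrative examples taken from the diffusers documentation, not part of this diff:

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# 0.27.x style: resume_download was a documented flag.
# pipe.load_textual_inversion("sd-concepts-library/cat-toy", resume_download=True)

# 0.28.0: the flag is deprecated and ignored; downloads resume by default,
# so the argument can simply be dropped.
pipe.load_textual_inversion("sd-concepts-library/cat-toy")
```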
diffusers/loaders/unet.py CHANGED
@@ -27,11 +27,13 @@ from torch import nn

  from ..models.embeddings import (
  ImageProjection,
+ IPAdapterFaceIDImageProjection,
+ IPAdapterFaceIDPlusImageProjection,
  IPAdapterFullImageProjection,
  IPAdapterPlusImageProjection,
  MultiIPAdapterImageProjection,
  )
- from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+ from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
  from ..utils import (
  USE_PEFT_BACKEND,
  _get_model_file,
@@ -42,11 +44,7 @@ from ..utils import (
  set_adapter_layers,
  set_weights_and_activate_adapters,
  )
- from .single_file_utils import (
- convert_stable_cascade_unet_single_file_to_diffusers,
- infer_stable_cascade_single_file_config,
- load_single_file_model_checkpoint,
- )
+ from .unet_loader_utils import _maybe_expand_lora_scales
  from .utils import AttnProcsLayers


@@ -100,9 +98,9 @@ class UNet2DConditionLoadersMixin:
  force_download (`bool`, *optional*, defaults to `False`):
  Whether or not to force the (re-)download of the model weights and configuration files, overriding the
  cached versions if they exist.
- resume_download (`bool`, *optional*, defaults to `False`):
- Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
- incompletely downloaded files are deleted.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
+ of Diffusers.
  proxies (`Dict[str, str]`, *optional*):
  A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
  'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -146,7 +144,7 @@ class UNet2DConditionLoadersMixin:

  cache_dir = kwargs.pop("cache_dir", None)
  force_download = kwargs.pop("force_download", False)
- resume_download = kwargs.pop("resume_download", False)
+ resume_download = kwargs.pop("resume_download", None)
  proxies = kwargs.pop("proxies", None)
  local_files_only = kwargs.pop("local_files_only", None)
  token = kwargs.pop("token", None)
@@ -214,7 +212,7 @@ class UNet2DConditionLoadersMixin:
  subfolder=subfolder,
  user_agent=user_agent,
  )
- state_dict = torch.load(model_file, map_location="cpu")
+ state_dict = load_state_dict(model_file)
  else:
  state_dict = pretrained_model_name_or_path_or_dict

@@ -356,7 +354,11 @@ class UNet2DConditionLoadersMixin:
  for _, component in _pipeline.components.items():
  if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
  is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
- is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+ is_sequential_cpu_offload = (
+ isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+ or hasattr(component._hf_hook, "hooks")
+ and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+ )

  logger.info(
  "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
@@ -564,7 +566,7 @@ class UNet2DConditionLoadersMixin:
  def set_adapters(
  self,
  adapter_names: Union[List[str], str],
- weights: Optional[Union[List[float], float]] = None,
+ weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
  ):
  """
  Set the currently active adapters for use in the UNet.
@@ -597,9 +599,9 @@ class UNet2DConditionLoadersMixin:

  adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

- if weights is None:
- weights = [1.0] * len(adapter_names)
- elif isinstance(weights, float):
+ # Expand weights into a list, one entry per adapter
+ # examples for e.g. 2 adapters: [{...}, 7] -> [7,7] ; None -> [None, None]
+ if not isinstance(weights, list):
  weights = [weights] * len(adapter_names)

  if len(adapter_names) != len(weights):
@@ -607,6 +609,13 @@ class UNet2DConditionLoadersMixin:
  f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
  )

+ # Set None values to default of 1.0
+ # e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0]
+ weights = [w if w is not None else 1.0 for w in weights]
+
+ # e.g. [{...}, 7] -> [{expanded dict...}, 7]
+ weights = _maybe_expand_lora_scales(self, weights)
+
  set_weights_and_activate_adapters(self, adapter_names, weights)

  def disable_lora(self):
@@ -748,6 +757,90 @@ class UNet2DConditionLoadersMixin:
  diffusers_name = diffusers_name.replace("proj.3", "norm")
  updated_state_dict[diffusers_name] = value

+ elif "perceiver_resampler.proj_in.weight" in state_dict:
+ # IP-Adapter Face ID Plus
+ id_embeddings_dim = state_dict["proj.0.weight"].shape[1]
+ embed_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[0]
+ hidden_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[1]
+ output_dims = state_dict["perceiver_resampler.proj_out.weight"].shape[0]
+ heads = state_dict["perceiver_resampler.layers.0.0.to_q.weight"].shape[0] // 64
+
+ with init_context():
+ image_projection = IPAdapterFaceIDPlusImageProjection(
+ embed_dims=embed_dims,
+ output_dims=output_dims,
+ hidden_dims=hidden_dims,
+ heads=heads,
+ id_embeddings_dim=id_embeddings_dim,
+ )
+
+ for key, value in state_dict.items():
+ diffusers_name = key.replace("perceiver_resampler.", "")
+ diffusers_name = diffusers_name.replace("0.to", "attn.to")
+ diffusers_name = diffusers_name.replace("0.1.0.", "0.ff.0.")
+ diffusers_name = diffusers_name.replace("0.1.1.weight", "0.ff.1.net.0.proj.weight")
+ diffusers_name = diffusers_name.replace("0.1.3.weight", "0.ff.1.net.2.weight")
+ diffusers_name = diffusers_name.replace("1.1.0.", "1.ff.0.")
+ diffusers_name = diffusers_name.replace("1.1.1.weight", "1.ff.1.net.0.proj.weight")
+ diffusers_name = diffusers_name.replace("1.1.3.weight", "1.ff.1.net.2.weight")
+ diffusers_name = diffusers_name.replace("2.1.0.", "2.ff.0.")
+ diffusers_name = diffusers_name.replace("2.1.1.weight", "2.ff.1.net.0.proj.weight")
+ diffusers_name = diffusers_name.replace("2.1.3.weight", "2.ff.1.net.2.weight")
+ diffusers_name = diffusers_name.replace("3.1.0.", "3.ff.0.")
+ diffusers_name = diffusers_name.replace("3.1.1.weight", "3.ff.1.net.0.proj.weight")
+ diffusers_name = diffusers_name.replace("3.1.3.weight", "3.ff.1.net.2.weight")
+ diffusers_name = diffusers_name.replace("layers.0.0", "layers.0.ln0")
+ diffusers_name = diffusers_name.replace("layers.0.1", "layers.0.ln1")
+ diffusers_name = diffusers_name.replace("layers.1.0", "layers.1.ln0")
+ diffusers_name = diffusers_name.replace("layers.1.1", "layers.1.ln1")
+ diffusers_name = diffusers_name.replace("layers.2.0", "layers.2.ln0")
+ diffusers_name = diffusers_name.replace("layers.2.1", "layers.2.ln1")
+ diffusers_name = diffusers_name.replace("layers.3.0", "layers.3.ln0")
+ diffusers_name = diffusers_name.replace("layers.3.1", "layers.3.ln1")
+
+ if "norm1" in diffusers_name:
+ updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
+ elif "norm2" in diffusers_name:
+ updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
+ elif "to_kv" in diffusers_name:
+ v_chunk = value.chunk(2, dim=0)
+ updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
+ updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
+ elif "to_out" in diffusers_name:
+ updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
+ elif "proj.0.weight" == diffusers_name:
+ updated_state_dict["proj.net.0.proj.weight"] = value
+ elif "proj.0.bias" == diffusers_name:
+ updated_state_dict["proj.net.0.proj.bias"] = value
+ elif "proj.2.weight" == diffusers_name:
+ updated_state_dict["proj.net.2.weight"] = value
+ elif "proj.2.bias" == diffusers_name:
+ updated_state_dict["proj.net.2.bias"] = value
+ else:
+ updated_state_dict[diffusers_name] = value
+
+ elif "norm.weight" in state_dict:
+ # IP-Adapter Face ID
+ id_embeddings_dim_in = state_dict["proj.0.weight"].shape[1]
+ id_embeddings_dim_out = state_dict["proj.0.weight"].shape[0]
+ multiplier = id_embeddings_dim_out // id_embeddings_dim_in
+ norm_layer = "norm.weight"
+ cross_attention_dim = state_dict[norm_layer].shape[0]
+ num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim
+
+ with init_context():
+ image_projection = IPAdapterFaceIDImageProjection(
+ cross_attention_dim=cross_attention_dim,
+ image_embed_dim=id_embeddings_dim_in,
+ mult=multiplier,
+ num_tokens=num_tokens,
+ )
+
+ for key, value in state_dict.items():
+ diffusers_name = key.replace("proj.0", "ff.net.0.proj")
+ diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
+ updated_state_dict[diffusers_name] = value
+
  else:
  # IP-Adapter Plus
  num_image_text_embeds = state_dict["latents"].shape[1]
@@ -839,6 +932,7 @@ class UNet2DConditionLoadersMixin:
  AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
  )
  attn_procs[name] = attn_processor_class()
+
  else:
  attn_processor_class = (
  IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
@@ -851,6 +945,12 @@ class UNet2DConditionLoadersMixin:
  elif "proj.3.weight" in state_dict["image_proj"]:
  # IP-Adapter Full Face
  num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token
+ elif "perceiver_resampler.proj_in.weight" in state_dict["image_proj"]:
+ # IP-Adapter Face ID Plus
+ num_image_text_embeds += [4]
+ elif "norm.weight" in state_dict["image_proj"]:
+ # IP-Adapter Face ID
+ num_image_text_embeds += [4]
  else:
  # IP-Adapter Plus
  num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
@@ -902,102 +1002,55 @@

  self.to(dtype=self.dtype, device=self.device)

-
- class FromOriginalUNetMixin:
- """
- Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`StableCascadeUNet`].
- """
-
- @classmethod
- @validate_hf_hub_args
- def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
- r"""
- Instantiate a [`StableCascadeUNet`] from pretrained StableCascadeUNet weights saved in the original `.ckpt` or
- `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
-
- Parameters:
- pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
- Can be either:
- - A link to the `.ckpt` file (for example
- `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
- - A path to a *file* containing all pipeline weights.
- config: (`dict`, *optional*):
- Dictionary containing the configuration of the model:
- torch_dtype (`str` or `torch.dtype`, *optional*):
- Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
- dtype is automatically derived from the model's weights.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- resume_download (`bool`, *optional*, defaults to `False`):
- Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
- incompletely downloaded files are deleted.
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to True, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- kwargs (remaining dictionary of keyword arguments, *optional*):
- Can be used to overwrite load and saveable variables of the model.
-
- """
- class_name = cls.__name__
- if class_name != "StableCascadeUNet":
- raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet")
-
- config = kwargs.pop("config", None)
- resume_download = kwargs.pop("resume_download", False)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- token = kwargs.pop("token", None)
- cache_dir = kwargs.pop("cache_dir", None)
- local_files_only = kwargs.pop("local_files_only", None)
- revision = kwargs.pop("revision", None)
- torch_dtype = kwargs.pop("torch_dtype", None)
-
- checkpoint = load_single_file_model_checkpoint(
- pretrained_model_link_or_path,
- resume_download=resume_download,
- force_download=force_download,
- proxies=proxies,
- token=token,
- cache_dir=cache_dir,
- local_files_only=local_files_only,
- revision=revision,
- )
-
- if config is None:
- config = infer_stable_cascade_single_file_config(checkpoint)
- model_config = cls.load_config(**config, **kwargs)
- else:
- model_config = config
-
- ctx = init_empty_weights if is_accelerate_available() else nullcontext
- with ctx():
- model = cls.from_config(model_config, **kwargs)
-
- diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint)
- if is_accelerate_available():
- unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
- if len(unexpected_keys) > 0:
- logger.warn(
- f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
- )
-
- else:
- model.load_state_dict(diffusers_format_checkpoint)
-
- if torch_dtype is not None:
- model.to(torch_dtype)
-
- return model
+ def _load_ip_adapter_loras(self, state_dicts):
+ lora_dicts = {}
+ for key_id, name in enumerate(self.attn_processors.keys()):
+ for i, state_dict in enumerate(state_dicts):
+ if f"{key_id}.to_k_lora.down.weight" in state_dict["ip_adapter"]:
+ if i not in lora_dicts:
+ lora_dicts[i] = {}
+ lora_dicts[i].update(
+ {
+ f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][
+ f"{key_id}.to_k_lora.down.weight"
+ ]
+ }
+ )
+ lora_dicts[i].update(
+ {
+ f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][
+ f"{key_id}.to_q_lora.down.weight"
+ ]
+ }
+ )
+ lora_dicts[i].update(
+ {
+ f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][
+ f"{key_id}.to_v_lora.down.weight"
+ ]
+ }
+ )
+ lora_dicts[i].update(
+ {
+ f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][
+ f"{key_id}.to_out_lora.down.weight"
+ ]
+ }
+ )
+ lora_dicts[i].update(
+ {f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}
+ )
+ lora_dicts[i].update(
+ {f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}
+ )
+ lora_dicts[i].update(
+ {f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}
+ )
+ lora_dicts[i].update(
+ {
+ f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][
+ f"{key_id}.to_out_lora.up.weight"
+ ]
+ }
+ )
+ return lora_dicts
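Note on the `set_adapters()` change above: `weights` is widened from floats to optional per-block scale dictionaries, which are expanded by the new `_maybe_expand_lora_scales` helper added in `diffusers/loaders/unet_loader_utils.py` (shown next). A minimal usage sketch, assuming an SDXL pipeline; the LoRA repo id and filename are illustrative examples, not part of this diff:

```python
import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")

# 0.27.x style: one float per adapter still works.
pipe.unet.set_adapters(["pixel"], weights=[0.7])

# 0.28.0: a dict can scale individual UNet blocks; anything left out
# falls back to the default scale of 1.0.
pipe.unet.set_adapters(
    ["pixel"],
    weights=[{"down": 0.6, "mid": 1.0, "up": {"block_0": 0.8, "block_1": [1.0, 0.9, 0.7]}}],
)
```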
diffusers/loaders/unet_loader_utils.py ADDED
@@ -0,0 +1,163 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import copy
+ from typing import TYPE_CHECKING, Dict, List, Union
+
+ from ..utils import logging
+
+
+ if TYPE_CHECKING:
+ # import here to avoid circular imports
+ from ..models import UNet2DConditionModel
+
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+ def _translate_into_actual_layer_name(name):
+ """Translate user-friendly name (e.g. 'mid') into actual layer name (e.g. 'mid_block.attentions.0')"""
+ if name == "mid":
+ return "mid_block.attentions.0"
+
+ updown, block, attn = name.split(".")
+
+ updown = updown.replace("down", "down_blocks").replace("up", "up_blocks")
+ block = block.replace("block_", "")
+ attn = "attentions." + attn
+
+ return ".".join((updown, block, attn))
+
+
+ def _maybe_expand_lora_scales(
+ unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]], default_scale=1.0
+ ):
+ blocks_with_transformer = {
+ "down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
+ "up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
+ }
+ transformer_per_block = {"down": unet.config.layers_per_block, "up": unet.config.layers_per_block + 1}
+
+ expanded_weight_scales = [
+ _maybe_expand_lora_scales_for_one_adapter(
+ weight_for_adapter,
+ blocks_with_transformer,
+ transformer_per_block,
+ unet.state_dict(),
+ default_scale=default_scale,
+ )
+ for weight_for_adapter in weight_scales
+ ]
+
+ return expanded_weight_scales
+
+
+ def _maybe_expand_lora_scales_for_one_adapter(
+ scales: Union[float, Dict],
+ blocks_with_transformer: Dict[str, int],
+ transformer_per_block: Dict[str, int],
+ state_dict: None,
+ default_scale: float = 1.0,
+ ):
+ """
+ Expands the inputs into a more granular dictionary. See the example below for more details.
+
+ Parameters:
+ scales (`Union[float, Dict]`):
+ Scales dict to expand.
+ blocks_with_transformer (`Dict[str, int]`):
+ Dict with keys 'up' and 'down', showing which blocks have transformer layers
+ transformer_per_block (`Dict[str, int]`):
+ Dict with keys 'up' and 'down', showing how many transformer layers each block has
+
+ E.g. turns
+ ```python
+ scales = {"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}}
+ blocks_with_transformer = {"down": [1, 2], "up": [0, 1]}
+ transformer_per_block = {"down": 2, "up": 3}
+ ```
+ into
+ ```python
+ {
+ "down.block_1.0": 2,
+ "down.block_1.1": 2,
+ "down.block_2.0": 2,
+ "down.block_2.1": 2,
+ "mid": 3,
+ "up.block_0.0": 4,
+ "up.block_0.1": 4,
+ "up.block_0.2": 4,
+ "up.block_1.0": 5,
+ "up.block_1.1": 6,
+ "up.block_1.2": 7,
+ }
+ ```
+ """
+ if sorted(blocks_with_transformer.keys()) != ["down", "up"]:
+ raise ValueError("blocks_with_transformer needs to be a dict with keys `'down' and `'up'`")
+
+ if sorted(transformer_per_block.keys()) != ["down", "up"]:
+ raise ValueError("transformer_per_block needs to be a dict with keys `'down' and `'up'`")
+
+ if not isinstance(scales, dict):
+ # don't expand if scales is a single number
+ return scales
+
+ scales = copy.deepcopy(scales)
+
+ if "mid" not in scales:
+ scales["mid"] = default_scale
+ elif isinstance(scales["mid"], list):
+ if len(scales["mid"]) == 1:
+ scales["mid"] = scales["mid"][0]
+ else:
+ raise ValueError(f"Expected 1 scales for mid, got {len(scales['mid'])}.")
+
+ for updown in ["up", "down"]:
+ if updown not in scales:
+ scales[updown] = default_scale
+
+ # eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}}
+ if not isinstance(scales[updown], dict):
+ scales[updown] = {f"block_{i}": copy.deepcopy(scales[updown]) for i in blocks_with_transformer[updown]}
+
+ # eg {"down": {"block_1": 1}} to {"down": {"block_1": [1, 1]}}
+ for i in blocks_with_transformer[updown]:
+ block = f"block_{i}"
+ # set not assigned blocks to default scale
+ if block not in scales[updown]:
+ scales[updown][block] = default_scale
+ if not isinstance(scales[updown][block], list):
+ scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])]
+ elif len(scales[updown][block]) == 1:
+ # a list specifying scale to each masked IP input
+ scales[updown][block] = scales[updown][block] * transformer_per_block[updown]
+ elif len(scales[updown][block]) != transformer_per_block[updown]:
+ raise ValueError(
+ f"Expected {transformer_per_block[updown]} scales for {updown}.{block}, got {len(scales[updown][block])}."
+ )
+
+ # eg {"down": "block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1}
+ for i in blocks_with_transformer[updown]:
+ block = f"block_{i}"
+ for tf_idx, value in enumerate(scales[updown][block]):
+ scales[f"{updown}.{block}.{tf_idx}"] = value
+
+ del scales[updown]
+
+ for layer in scales.keys():
+ if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
+ raise ValueError(
+ f"Can't set lora scale for layer {layer}. It either doesn't exist in this unet or it has no attentions."
+ )
+
+ return {_translate_into_actual_layer_name(name): weight for name, weight in scales.items()}
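For reference, the user-facing block names accepted above ("mid", "down.block_1.0", ...) map onto UNet module paths exactly as `_translate_into_actual_layer_name` describes. A small standalone re-implementation, just to illustrate the mapping documented by the new file (the real helper is the one added above):

```python
def translate(name: str) -> str:
    # Mirrors _translate_into_actual_layer_name from unet_loader_utils.py.
    if name == "mid":
        return "mid_block.attentions.0"
    updown, block, attn = name.split(".")
    updown = updown.replace("down", "down_blocks").replace("up", "up_blocks")
    return ".".join((updown, block.replace("block_", ""), "attentions." + attn))


assert translate("mid") == "mid_block.attentions.0"
assert translate("down.block_1.0") == "down_blocks.1.attentions.0"
assert translate("up.block_0.2") == "up_blocks.0.attentions.2"
```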