diffusers 0.23.1__py3-none-any.whl → 0.25.0__py3-none-any.whl

Files changed (238)
  1. diffusers/__init__.py +26 -2
  2. diffusers/commands/fp16_safetensors.py +10 -11
  3. diffusers/configuration_utils.py +13 -8
  4. diffusers/dependency_versions_check.py +0 -1
  5. diffusers/dependency_versions_table.py +5 -5
  6. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  7. diffusers/image_processor.py +463 -51
  8. diffusers/loaders/__init__.py +82 -0
  9. diffusers/loaders/ip_adapter.py +159 -0
  10. diffusers/loaders/lora.py +1553 -0
  11. diffusers/loaders/lora_conversion_utils.py +284 -0
  12. diffusers/loaders/single_file.py +637 -0
  13. diffusers/loaders/textual_inversion.py +455 -0
  14. diffusers/loaders/unet.py +828 -0
  15. diffusers/loaders/utils.py +59 -0
  16. diffusers/models/__init__.py +26 -9
  17. diffusers/models/activations.py +9 -6
  18. diffusers/models/attention.py +301 -29
  19. diffusers/models/attention_flax.py +9 -1
  20. diffusers/models/attention_processor.py +378 -6
  21. diffusers/models/autoencoders/__init__.py +5 -0
  22. diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +17 -12
  23. diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +47 -23
  24. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +402 -0
  25. diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +24 -28
  26. diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +51 -44
  27. diffusers/models/{vae.py → autoencoders/vae.py} +71 -17
  28. diffusers/models/controlnet.py +59 -39
  29. diffusers/models/controlnet_flax.py +19 -18
  30. diffusers/models/downsampling.py +338 -0
  31. diffusers/models/embeddings.py +112 -29
  32. diffusers/models/embeddings_flax.py +2 -0
  33. diffusers/models/lora.py +131 -1
  34. diffusers/models/modeling_flax_utils.py +14 -8
  35. diffusers/models/modeling_outputs.py +17 -0
  36. diffusers/models/modeling_utils.py +37 -29
  37. diffusers/models/normalization.py +110 -4
  38. diffusers/models/resnet.py +299 -652
  39. diffusers/models/transformer_2d.py +22 -5
  40. diffusers/models/transformer_temporal.py +183 -1
  41. diffusers/models/unet_2d_blocks_flax.py +5 -0
  42. diffusers/models/unet_2d_condition.py +46 -0
  43. diffusers/models/unet_2d_condition_flax.py +13 -13
  44. diffusers/models/unet_3d_blocks.py +957 -173
  45. diffusers/models/unet_3d_condition.py +16 -8
  46. diffusers/models/unet_kandinsky3.py +535 -0
  47. diffusers/models/unet_motion_model.py +48 -33
  48. diffusers/models/unet_spatio_temporal_condition.py +489 -0
  49. diffusers/models/upsampling.py +454 -0
  50. diffusers/models/uvit_2d.py +471 -0
  51. diffusers/models/vae_flax.py +7 -0
  52. diffusers/models/vq_model.py +12 -3
  53. diffusers/optimization.py +16 -9
  54. diffusers/pipelines/__init__.py +137 -76
  55. diffusers/pipelines/amused/__init__.py +62 -0
  56. diffusers/pipelines/amused/pipeline_amused.py +328 -0
  57. diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
  58. diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
  59. diffusers/pipelines/animatediff/pipeline_animatediff.py +66 -8
  60. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
  61. diffusers/pipelines/auto_pipeline.py +23 -13
  62. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
  63. diffusers/pipelines/controlnet/pipeline_controlnet.py +238 -35
  64. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +148 -37
  65. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +155 -41
  66. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +123 -43
  67. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +216 -39
  68. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +106 -34
  69. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
  70. diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
  71. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
  72. diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
  73. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
  74. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
  75. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
  76. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
  77. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
  78. diffusers/pipelines/deprecated/__init__.py +153 -0
  79. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
  80. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +177 -34
  81. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +182 -37
  82. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
  83. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
  84. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
  85. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
  86. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
  87. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
  88. diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
  89. diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
  90. diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
  91. diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
  92. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
  93. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +5 -4
  94. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
  95. diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
  96. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
  97. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
  98. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +8 -7
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
  100. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +34 -13
  101. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +7 -6
  102. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +12 -11
  103. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +17 -11
  104. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +11 -10
  105. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +14 -13
  106. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
  107. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
  108. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
  109. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +83 -51
  110. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
  111. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +7 -6
  112. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +7 -6
  113. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +7 -6
  114. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
  115. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
  116. diffusers/pipelines/dit/pipeline_dit.py +1 -0
  117. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  118. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
  119. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  120. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
  122. diffusers/pipelines/kandinsky3/__init__.py +49 -0
  123. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
  124. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +589 -0
  125. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +654 -0
  126. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +111 -11
  127. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +102 -9
  128. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
  129. diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
  130. diffusers/pipelines/onnx_utils.py +8 -5
  131. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
  132. diffusers/pipelines/pipeline_flax_utils.py +11 -8
  133. diffusers/pipelines/pipeline_utils.py +63 -42
  134. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +247 -38
  135. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
  136. diffusers/pipelines/stable_diffusion/__init__.py +37 -65
  137. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +75 -78
  138. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
  139. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
  140. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
  141. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +174 -11
  142. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
  143. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
  144. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +178 -11
  145. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +224 -13
  146. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +74 -20
  147. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -0
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +7 -0
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
  151. diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
  152. diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +6 -2
  153. diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
  154. diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +3 -3
  155. diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
  156. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +3 -2
  157. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +4 -3
  158. diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
  159. diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +7 -1
  160. diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
  161. diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +51 -7
  162. diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
  163. diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +57 -8
  164. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
  165. diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
  166. diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +68 -10
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +194 -17
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +205 -16
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +206 -17
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +23 -17
  171. diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
  172. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +652 -0
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +115 -14
  175. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +6 -0
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +23 -3
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +334 -10
  179. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +1331 -0
  180. diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
  181. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
  182. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  183. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
  184. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
  185. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
  186. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
  187. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -1
  188. diffusers/schedulers/__init__.py +4 -4
  189. diffusers/schedulers/deprecated/__init__.py +50 -0
  190. diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
  191. diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
  192. diffusers/schedulers/scheduling_amused.py +162 -0
  193. diffusers/schedulers/scheduling_consistency_models.py +2 -0
  194. diffusers/schedulers/scheduling_ddim.py +1 -3
  195. diffusers/schedulers/scheduling_ddim_inverse.py +2 -7
  196. diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
  197. diffusers/schedulers/scheduling_ddpm.py +47 -3
  198. diffusers/schedulers/scheduling_ddpm_parallel.py +47 -3
  199. diffusers/schedulers/scheduling_deis_multistep.py +28 -6
  200. diffusers/schedulers/scheduling_dpmsolver_multistep.py +28 -6
  201. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +28 -6
  202. diffusers/schedulers/scheduling_dpmsolver_sde.py +3 -3
  203. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +28 -6
  204. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +59 -3
  205. diffusers/schedulers/scheduling_euler_discrete.py +102 -16
  206. diffusers/schedulers/scheduling_heun_discrete.py +17 -5
  207. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +17 -5
  208. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +17 -5
  209. diffusers/schedulers/scheduling_lcm.py +123 -29
  210. diffusers/schedulers/scheduling_lms_discrete.py +3 -3
  211. diffusers/schedulers/scheduling_pndm.py +1 -3
  212. diffusers/schedulers/scheduling_repaint.py +1 -3
  213. diffusers/schedulers/scheduling_unipc_multistep.py +28 -6
  214. diffusers/schedulers/scheduling_utils.py +3 -1
  215. diffusers/schedulers/scheduling_utils_flax.py +3 -1
  216. diffusers/training_utils.py +1 -1
  217. diffusers/utils/__init__.py +1 -2
  218. diffusers/utils/constants.py +10 -12
  219. diffusers/utils/dummy_pt_objects.py +75 -0
  220. diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
  221. diffusers/utils/dynamic_modules_utils.py +18 -22
  222. diffusers/utils/export_utils.py +8 -3
  223. diffusers/utils/hub_utils.py +24 -36
  224. diffusers/utils/logging.py +11 -11
  225. diffusers/utils/outputs.py +5 -5
  226. diffusers/utils/peft_utils.py +88 -44
  227. diffusers/utils/state_dict_utils.py +8 -0
  228. diffusers/utils/testing_utils.py +199 -1
  229. diffusers/utils/torch_utils.py +4 -4
  230. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/METADATA +86 -69
  231. diffusers-0.25.0.dist-info/RECORD +360 -0
  232. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
  233. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
  234. diffusers/loaders.py +0 -3336
  235. diffusers-0.23.1.dist-info/RECORD +0 -323
  236. /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
  237. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
  238. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
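Among the changes, the monolithic `diffusers/loaders.py` (removed, entry 234) is split into the new `diffusers/loaders/` package, which also introduces IP-Adapter support (`diffusers/loaders/ip_adapter.py`, entry 9). As a rough sketch of how the new loader is exercised at the pipeline level — the checkpoint, weight file, and image URL below are illustrative assumptions, not taken from this diff:

```py
import torch
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Assumed example weights; load_ip_adapter ultimately routes the UNet-side weights
# through the new UNet2DConditionLoadersMixin._load_ip_adapter_weights shown below.
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

image = load_image("https://example.com/reference.png")  # placeholder conditioning image
out = pipeline(prompt="a cat, best quality", ip_adapter_image=image).images[0]
```

The full diff reproduced below is the new `diffusers/loaders/unet.py` (entry 15 in the list above).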
--- /dev/null
+++ diffusers/loaders/unet.py
@@ -0,0 +1,828 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import inspect
+import os
+from collections import defaultdict
+from contextlib import nullcontext
+from functools import partial
+from typing import Callable, Dict, List, Optional, Union
+
+import safetensors
+import torch
+import torch.nn.functional as F
+from huggingface_hub.utils import validate_hf_hub_args
+from torch import nn
+
+from ..models.embeddings import ImageProjection, MLPProjection, Resampler
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ..utils import (
+    USE_PEFT_BACKEND,
+    _get_model_file,
+    delete_adapter_layers,
+    is_accelerate_available,
+    logging,
+    set_adapter_layers,
+    set_weights_and_activate_adapters,
+)
+from .utils import AttnProcsLayers
+
+
+if is_accelerate_available():
+    from accelerate import init_empty_weights
+    from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
+
+logger = logging.get_logger(__name__)
+
+
+TEXT_ENCODER_NAME = "text_encoder"
+UNET_NAME = "unet"
+
+LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
+LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+
+CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
+CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
+
+
+class UNet2DConditionLoadersMixin:
+    """
+    Load LoRA layers into a [`UNet2DConditionModel`].
+    """
+
+    text_encoder_name = TEXT_ENCODER_NAME
+    unet_name = UNET_NAME
+
+    @validate_hf_hub_args
+    def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+        r"""
+        Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to
+        be defined in
+        [`attention_processor.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py)
+        and be a `torch.nn.Module` class.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a directory (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+                incompletely downloaded files are deleted.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+            mirror (`str`, *optional*):
+                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
+                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+                information.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.unet.load_attn_procs(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        ```
+        """
+        from ..models.attention_processor import CustomDiffusionAttnProcessor
+        from ..models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
+
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
+        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+        network_alphas = kwargs.pop("network_alphas", None)
+
+        _pipeline = kwargs.pop("_pipeline", None)
+
+        is_network_alphas_none = network_alphas is None
+
+        allow_pickle = False
+
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        if low_cpu_mem_usage and not is_accelerate_available():
+            low_cpu_mem_usage = False
+            logger.warning(
+                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                " install accelerate\n```\n."
+            )
+
+        model_file = None
+        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+            # Let's first try to load .safetensors weights
+            if (use_safetensors and weight_name is None) or (
+                weight_name is not None and weight_name.endswith(".safetensors")
+            ):
+                try:
+                    model_file = _get_model_file(
+                        pretrained_model_name_or_path_or_dict,
+                        weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        resume_download=resume_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        token=token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        user_agent=user_agent,
+                    )
+                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                except IOError as e:
+                    if not allow_pickle:
+                        raise e
+                    # try loading non-safetensors weights
+                    pass
+            if model_file is None:
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path_or_dict,
+                    weights_name=weight_name or LORA_WEIGHT_NAME,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+                state_dict = torch.load(model_file, map_location="cpu")
+        else:
+            state_dict = pretrained_model_name_or_path_or_dict
+
+        # fill attn processors
+        lora_layers_list = []
+
+        is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) and not USE_PEFT_BACKEND
+        is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
+
+        if is_lora:
+            # correct keys
+            state_dict, network_alphas = self.convert_state_dict_legacy_attn_format(state_dict, network_alphas)
+
+            if network_alphas is not None:
+                network_alphas_keys = list(network_alphas.keys())
+                used_network_alphas_keys = set()
+
+            lora_grouped_dict = defaultdict(dict)
+            mapped_network_alphas = {}
+
+            all_keys = list(state_dict.keys())
+            for key in all_keys:
+                value = state_dict.pop(key)
+                attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+                lora_grouped_dict[attn_processor_key][sub_key] = value
+
+                # Create another `mapped_network_alphas` dictionary so that we can properly map them.
+                if network_alphas is not None:
+                    for k in network_alphas_keys:
+                        if k.replace(".alpha", "") in key:
+                            mapped_network_alphas.update({attn_processor_key: network_alphas.get(k)})
+                            used_network_alphas_keys.add(k)
+
+            if not is_network_alphas_none:
+                if len(set(network_alphas_keys) - used_network_alphas_keys) > 0:
+                    raise ValueError(
+                        f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
+                    )
+
+            if len(state_dict) > 0:
+                raise ValueError(
+                    f"The `state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
+                )
+
+            for key, value_dict in lora_grouped_dict.items():
+                attn_processor = self
+                for sub_key in key.split("."):
+                    attn_processor = getattr(attn_processor, sub_key)
+
+                # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
+                # or add_{k,v,q,out_proj}_proj_lora layers.
+                rank = value_dict["lora.down.weight"].shape[0]
+
+                if isinstance(attn_processor, LoRACompatibleConv):
+                    in_features = attn_processor.in_channels
+                    out_features = attn_processor.out_channels
+                    kernel_size = attn_processor.kernel_size
+
+                    ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+                    with ctx():
+                        lora = LoRAConv2dLayer(
+                            in_features=in_features,
+                            out_features=out_features,
+                            rank=rank,
+                            kernel_size=kernel_size,
+                            stride=attn_processor.stride,
+                            padding=attn_processor.padding,
+                            network_alpha=mapped_network_alphas.get(key),
+                        )
+                elif isinstance(attn_processor, LoRACompatibleLinear):
+                    ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+                    with ctx():
+                        lora = LoRALinearLayer(
+                            attn_processor.in_features,
+                            attn_processor.out_features,
+                            rank,
+                            mapped_network_alphas.get(key),
+                        )
+                else:
+                    raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
+
+                value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
+                lora_layers_list.append((attn_processor, lora))
+
+                if low_cpu_mem_usage:
+                    device = next(iter(value_dict.values())).device
+                    dtype = next(iter(value_dict.values())).dtype
+                    load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
+                else:
+                    lora.load_state_dict(value_dict)
+
+        elif is_custom_diffusion:
+            attn_processors = {}
+            custom_diffusion_grouped_dict = defaultdict(dict)
+            for key, value in state_dict.items():
+                if len(value) == 0:
+                    custom_diffusion_grouped_dict[key] = {}
+                else:
+                    if "to_out" in key:
+                        attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+                    else:
+                        attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
+                    custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
+
+            for key, value_dict in custom_diffusion_grouped_dict.items():
+                if len(value_dict) == 0:
+                    attn_processors[key] = CustomDiffusionAttnProcessor(
+                        train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
+                    )
+                else:
+                    cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
+                    hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
+                    train_q_out = "to_q_custom_diffusion.weight" in value_dict
+                    attn_processors[key] = CustomDiffusionAttnProcessor(
+                        train_kv=True,
+                        train_q_out=train_q_out,
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                    )
+                    attn_processors[key].load_state_dict(value_dict)
+        elif USE_PEFT_BACKEND:
+            # In that case we have nothing to do as loading the adapter weights is already handled above by
+            # `set_peft_model_state_dict` on the Unet
+            pass
+        else:
+            raise ValueError(
+                f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
+            )
+
+        # <Unsafe code
+        # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
+        # Now we remove any existing hooks to the pipeline components.
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
+
+        # For PEFT backend the Unet is already offloaded at this stage as it is handled inside `load_lora_into_unet`
+        if not USE_PEFT_BACKEND:
+            if _pipeline is not None:
+                for _, component in _pipeline.components.items():
+                    if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
+                        is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                        is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+
+                        logger.info(
+                            "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                        )
+                        remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
+            # only custom diffusion needs to set attn processors
+            if is_custom_diffusion:
+                self.set_attn_processor(attn_processors)
+
+            # set lora layers
+            for target_module, lora_layer in lora_layers_list:
+                target_module.set_lora_layer(lora_layer)
+
+            self.to(dtype=self.dtype, device=self.device)
+
+            # Offload back.
+            if is_model_cpu_offload:
+                _pipeline.enable_model_cpu_offload()
+            elif is_sequential_cpu_offload:
+                _pipeline.enable_sequential_cpu_offload()
+            # Unsafe code />
+
+    def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
+        is_new_lora_format = all(
+            key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
+        )
+        if is_new_lora_format:
+            # Strip the `"unet"` prefix.
+            is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
+            if is_text_encoder_present:
+                warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
+                logger.warning(warn_message)
+            unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
+            state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
+
+        # change processor format to 'pure' LoRACompatibleLinear format
+        if any("processor" in k.split(".") for k in state_dict.keys()):
+
+            def format_to_lora_compatible(key):
+                if "processor" not in key.split("."):
+                    return key
+                return key.replace(".processor", "").replace("to_out_lora", "to_out.0.lora").replace("_lora", ".lora")
+
+            state_dict = {format_to_lora_compatible(k): v for k, v in state_dict.items()}
+
+            if network_alphas is not None:
+                network_alphas = {format_to_lora_compatible(k): v for k, v in network_alphas.items()}
+        return state_dict, network_alphas
+
+    def save_attn_procs(
+        self,
+        save_directory: Union[str, os.PathLike],
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+        **kwargs,
+    ):
+        r"""
+        Save attention processor layers to a directory so that they can be reloaded with the
+        [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save an attention processor to (will be created if it doesn't exist).
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training when
+                you need to call this function on all processes. In this case, set `is_main_process=True` only on the
+                main process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or with `pickle`.
+
+        Example:
+
+        ```py
+        import torch
+        from diffusers import DiffusionPipeline
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "CompVis/stable-diffusion-v1-4",
+            torch_dtype=torch.float16,
+        ).to("cuda")
+        pipeline.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
+        pipeline.unet.save_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
+        ```
+        """
+        from ..models.attention_processor import (
+            CustomDiffusionAttnProcessor,
+            CustomDiffusionAttnProcessor2_0,
+            CustomDiffusionXFormersAttnProcessor,
+        )
+
+        if os.path.isfile(save_directory):
+            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+            return
+
+        if save_function is None:
+            if safe_serialization:
+
+                def save_function(weights, filename):
+                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+            else:
+                save_function = torch.save
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        is_custom_diffusion = any(
+            isinstance(
+                x,
+                (CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
+            )
+            for (_, x) in self.attn_processors.items()
+        )
+        if is_custom_diffusion:
+            model_to_save = AttnProcsLayers(
+                {
+                    y: x
+                    for (y, x) in self.attn_processors.items()
+                    if isinstance(
+                        x,
+                        (
+                            CustomDiffusionAttnProcessor,
+                            CustomDiffusionAttnProcessor2_0,
+                            CustomDiffusionXFormersAttnProcessor,
+                        ),
+                    )
+                }
+            )
+            state_dict = model_to_save.state_dict()
+            for name, attn in self.attn_processors.items():
+                if len(attn.state_dict()) == 0:
+                    state_dict[name] = {}
+        else:
+            model_to_save = AttnProcsLayers(self.attn_processors)
+            state_dict = model_to_save.state_dict()
+
+        if weight_name is None:
+            if safe_serialization:
+                weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
+            else:
+                weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME
+
+        # Save the model
+        save_function(state_dict, os.path.join(save_directory, weight_name))
+        logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
+
+    def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
+        self.lora_scale = lora_scale
+        self._safe_fusing = safe_fusing
+        self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
+
+    def _fuse_lora_apply(self, module, adapter_names=None):
+        if not USE_PEFT_BACKEND:
+            if hasattr(module, "_fuse_lora"):
+                module._fuse_lora(self.lora_scale, self._safe_fusing)
+
+            if adapter_names is not None:
+                raise ValueError(
+                    "The `adapter_names` argument is not supported in your environment. Please switch"
+                    " to PEFT backend to use this argument by installing latest PEFT and transformers."
+                    " `pip install -U peft transformers`"
+                )
+        else:
+            from peft.tuners.tuners_utils import BaseTunerLayer
+
+            merge_kwargs = {"safe_merge": self._safe_fusing}
+
+            if isinstance(module, BaseTunerLayer):
+                if self.lora_scale != 1.0:
+                    module.scale_layer(self.lora_scale)
+
+                # For BC with previous PEFT versions, we need to check the signature
+                # of the `merge` method to see if it supports the `adapter_names` argument.
+                supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
+                if "adapter_names" in supported_merge_kwargs:
+                    merge_kwargs["adapter_names"] = adapter_names
+                elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
+                    raise ValueError(
+                        "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
+                        " to the latest version of PEFT. `pip install -U peft`"
+                    )
+
+                module.merge(**merge_kwargs)
+
+    def unfuse_lora(self):
+        self.apply(self._unfuse_lora_apply)
+
+    def _unfuse_lora_apply(self, module):
+        if not USE_PEFT_BACKEND:
+            if hasattr(module, "_unfuse_lora"):
+                module._unfuse_lora()
+        else:
+            from peft.tuners.tuners_utils import BaseTunerLayer
+
+            if isinstance(module, BaseTunerLayer):
+                module.unmerge()
+
+    def set_adapters(
+        self,
+        adapter_names: Union[List[str], str],
+        weights: Optional[Union[List[float], float]] = None,
+    ):
+        """
+        Set the currently active adapters for use in the UNet.
+
+        Args:
+            adapter_names (`List[str]` or `str`):
+                The names of the adapters to use.
+            weights (`Union[List[float], float]`, *optional*):
+                The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
+                adapters.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
+        ```
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for `set_adapters()`.")
+
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+
+        if weights is None:
+            weights = [1.0] * len(adapter_names)
+        elif isinstance(weights, float):
+            weights = [weights] * len(adapter_names)
+
+        if len(adapter_names) != len(weights):
+            raise ValueError(
+                f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
+            )
+
+        set_weights_and_activate_adapters(self, adapter_names, weights)
+
+    def disable_lora(self):
+        """
+        Disable the UNet's active LoRA layers.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.disable_lora()
+        ```
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+        set_adapter_layers(self, enabled=False)
+
+    def enable_lora(self):
+        """
+        Enable the UNet's active LoRA layers.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.enable_lora()
+        ```
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+        set_adapter_layers(self, enabled=True)
+
+    def delete_adapters(self, adapter_names: Union[List[str], str]):
+        """
+        Delete an adapter's LoRA layers from the UNet.
+
+        Args:
+            adapter_names (`Union[List[str], str]`):
+                The names (single string or list of strings) of the adapter to delete.
+
+        Example:
+
+        ```py
+        from diffusers import AutoPipelineForText2Image
+        import torch
+
+        pipeline = AutoPipelineForText2Image.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights(
+            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+        )
+        pipeline.delete_adapters("cinematic")
+        ```
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        if isinstance(adapter_names, str):
+            adapter_names = [adapter_names]
+
+        for adapter_name in adapter_names:
+            delete_adapter_layers(self, adapter_name)
+
+            # Pop also the corresponding adapter from the config
+            if hasattr(self, "peft_config"):
+                self.peft_config.pop(adapter_name, None)
+
+    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
+        updated_state_dict = {}
+        image_projection = None
+
+        if "proj.weight" in state_dict:
+            # IP-Adapter
+            num_image_text_embeds = 4
+            clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
+            cross_attention_dim = state_dict["proj.weight"].shape[0] // 4
+
+            image_projection = ImageProjection(
+                cross_attention_dim=cross_attention_dim,
+                image_embed_dim=clip_embeddings_dim,
+                num_image_text_embeds=num_image_text_embeds,
+            )
+
+            for key, value in state_dict.items():
+                diffusers_name = key.replace("proj", "image_embeds")
+                updated_state_dict[diffusers_name] = value
+
+        elif "proj.3.weight" in state_dict:
+            # IP-Adapter Full
+            clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
+            cross_attention_dim = state_dict["proj.3.weight"].shape[0]
+
+            image_projection = MLPProjection(
+                cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
+            )
+
+            for key, value in state_dict.items():
+                diffusers_name = key.replace("proj.0", "ff.net.0.proj")
+                diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
+                diffusers_name = diffusers_name.replace("proj.3", "norm")
+                updated_state_dict[diffusers_name] = value
+
+        else:
+            # IP-Adapter Plus
+            num_image_text_embeds = state_dict["latents"].shape[1]
+            embed_dims = state_dict["proj_in.weight"].shape[1]
+            output_dims = state_dict["proj_out.weight"].shape[0]
+            hidden_dims = state_dict["latents"].shape[2]
+            heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
+
+            image_projection = Resampler(
+                embed_dims=embed_dims,
+                output_dims=output_dims,
+                hidden_dims=hidden_dims,
+                heads=heads,
+                num_queries=num_image_text_embeds,
+            )
+
+            for key, value in state_dict.items():
+                diffusers_name = key.replace("0.to", "2.to")
+                diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight")
+                diffusers_name = diffusers_name.replace("1.0.bias", "3.0.bias")
+                diffusers_name = diffusers_name.replace("1.1.weight", "3.1.net.0.proj.weight")
+                diffusers_name = diffusers_name.replace("1.3.weight", "3.1.net.2.weight")
+
+                if "norm1" in diffusers_name:
+                    updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
+                elif "norm2" in diffusers_name:
+                    updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
+                elif "to_kv" in diffusers_name:
+                    v_chunk = value.chunk(2, dim=0)
+                    updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
+                    updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
+                elif "to_out" in diffusers_name:
+                    updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
+                else:
+                    updated_state_dict[diffusers_name] = value
+
+        image_projection.load_state_dict(updated_state_dict)
+        return image_projection
+
+    def _load_ip_adapter_weights(self, state_dict):
+        from ..models.attention_processor import (
+            AttnProcessor,
+            AttnProcessor2_0,
+            IPAdapterAttnProcessor,
+            IPAdapterAttnProcessor2_0,
+        )
+
+        if "proj.weight" in state_dict["image_proj"]:
+            # IP-Adapter
+            num_image_text_embeds = 4
+        elif "proj.3.weight" in state_dict["image_proj"]:
+            # IP-Adapter Full Face
+            num_image_text_embeds = 257  # 256 CLIP tokens + 1 CLS token
+        else:
+            # IP-Adapter Plus
+            num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
+
+        # Set encoder_hid_proj after loading ip_adapter weights,
+        # because `Resampler` also has `attn_processors`.
+        self.encoder_hid_proj = None
+
+        # set ip-adapter cross-attention processors & load state_dict
+        attn_procs = {}
+        key_id = 1
+        for name in self.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
+            if name.startswith("mid_block"):
+                hidden_size = self.config.block_out_channels[-1]
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(self.config.block_out_channels))[block_id]
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = self.config.block_out_channels[block_id]
+            if cross_attention_dim is None or "motion_modules" in name:
+                attn_processor_class = (
+                    AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
+                )
+                attn_procs[name] = attn_processor_class()
+            else:
+                attn_processor_class = (
+                    IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
+                )
+                attn_procs[name] = attn_processor_class(
+                    hidden_size=hidden_size,
+                    cross_attention_dim=cross_attention_dim,
+                    scale=1.0,
+                    num_tokens=num_image_text_embeds,
+                ).to(dtype=self.dtype, device=self.device)
+
+                value_dict = {}
+                for k, w in attn_procs[name].state_dict().items():
+                    value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]})
+
+                attn_procs[name].load_state_dict(value_dict)
+                key_id += 2
+
+        self.set_attn_processor(attn_procs)
+
+        # convert IP-Adapter Image Projection layers to diffusers
+        image_projection = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
+
+        self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
+        self.config.encoder_hid_dim_type = "ip_image_proj"
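The new `fuse_lora`/`unfuse_lora` entry points above ship without a docstring. A minimal sketch of the intended round trip, reusing the illustrative checkpoint and adapter names from the docstring examples above (not a prescribed workflow from this diff):

```py
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")

# Merge the LoRA deltas into the base UNet weights at half strength so inference
# no longer pays the LoRA indirection; safe_fusing=True maps to PEFT's safe_merge
# (see `merge_kwargs = {"safe_merge": self._safe_fusing}` in the diff above).
pipeline.unet.fuse_lora(lora_scale=0.5, safe_fusing=True)
image = pipeline("pixel art of a castle").images[0]

# Restore the original base weights (with the PEFT backend this calls module.unmerge()).
pipeline.unet.unfuse_lora()
```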