diffusers 0.29.2__py3-none-any.whl → 0.30.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_pipeline.py +2252 -0
  10. diffusers/loaders/peft.py +213 -5
  11. diffusers/loaders/single_file.py +3 -14
  12. diffusers/loaders/single_file_model.py +31 -10
  13. diffusers/loaders/single_file_utils.py +293 -8
  14. diffusers/loaders/textual_inversion.py +1 -6
  15. diffusers/loaders/unet.py +23 -208
  16. diffusers/models/__init__.py +20 -0
  17. diffusers/models/activations.py +22 -0
  18. diffusers/models/attention.py +386 -7
  19. diffusers/models/attention_processor.py +1937 -629
  20. diffusers/models/autoencoders/__init__.py +2 -0
  21. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  22. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1271 -0
  23. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  24. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  25. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  26. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  27. diffusers/models/autoencoders/vq_model.py +4 -4
  28. diffusers/models/controlnet.py +2 -3
  29. diffusers/models/controlnet_hunyuan.py +401 -0
  30. diffusers/models/controlnet_sd3.py +11 -11
  31. diffusers/models/controlnet_sparsectrl.py +789 -0
  32. diffusers/models/controlnet_xs.py +40 -10
  33. diffusers/models/downsampling.py +68 -0
  34. diffusers/models/embeddings.py +403 -36
  35. diffusers/models/model_loading_utils.py +1 -3
  36. diffusers/models/modeling_flax_utils.py +1 -6
  37. diffusers/models/modeling_utils.py +4 -16
  38. diffusers/models/normalization.py +203 -12
  39. diffusers/models/transformers/__init__.py +6 -0
  40. diffusers/models/transformers/auraflow_transformer_2d.py +543 -0
  41. diffusers/models/transformers/cogvideox_transformer_3d.py +485 -0
  42. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  43. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  44. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  45. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  46. diffusers/models/transformers/prior_transformer.py +1 -1
  47. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  48. diffusers/models/transformers/transformer_flux.py +455 -0
  49. diffusers/models/transformers/transformer_sd3.py +18 -4
  50. diffusers/models/unets/unet_1d_blocks.py +1 -1
  51. diffusers/models/unets/unet_2d_condition.py +8 -1
  52. diffusers/models/unets/unet_3d_blocks.py +51 -920
  53. diffusers/models/unets/unet_3d_condition.py +4 -1
  54. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  55. diffusers/models/unets/unet_kandinsky3.py +1 -1
  56. diffusers/models/unets/unet_motion_model.py +1330 -84
  57. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  58. diffusers/models/unets/unet_stable_cascade.py +1 -3
  59. diffusers/models/unets/uvit_2d.py +1 -1
  60. diffusers/models/upsampling.py +64 -0
  61. diffusers/models/vq_model.py +8 -4
  62. diffusers/optimization.py +1 -1
  63. diffusers/pipelines/__init__.py +100 -3
  64. diffusers/pipelines/animatediff/__init__.py +4 -0
  65. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  66. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  70. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  71. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  72. diffusers/pipelines/aura_flow/__init__.py +48 -0
  73. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  74. diffusers/pipelines/auto_pipeline.py +97 -19
  75. diffusers/pipelines/cogvideo/__init__.py +48 -0
  76. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +746 -0
  77. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  78. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  79. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  80. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  81. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  82. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  83. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  84. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  85. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  86. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  87. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  88. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  90. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  91. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  96. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  97. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  98. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  103. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  104. diffusers/pipelines/flux/__init__.py +47 -0
  105. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  106. diffusers/pipelines/flux/pipeline_output.py +21 -0
  107. diffusers/pipelines/free_init_utils.py +2 -0
  108. diffusers/pipelines/free_noise_utils.py +236 -0
  109. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  110. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  111. diffusers/pipelines/kolors/__init__.py +54 -0
  112. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  113. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  114. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  115. diffusers/pipelines/kolors/text_encoder.py +889 -0
  116. diffusers/pipelines/kolors/tokenizer.py +334 -0
  117. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  118. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  119. diffusers/pipelines/latte/__init__.py +48 -0
  120. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  121. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  122. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  123. diffusers/pipelines/lumina/__init__.py +48 -0
  124. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  125. diffusers/pipelines/pag/__init__.py +67 -0
  126. diffusers/pipelines/pag/pag_utils.py +237 -0
  127. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  128. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  129. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  130. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  131. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  132. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  133. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  138. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  139. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  140. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  141. diffusers/pipelines/pipeline_utils.py +2 -14
  142. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  143. diffusers/pipelines/stable_audio/__init__.py +50 -0
  144. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  145. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  146. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  147. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  156. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  157. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  158. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  160. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  161. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  162. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  163. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  164. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  165. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  166. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  167. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  168. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  175. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  179. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  180. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  181. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  182. diffusers/schedulers/__init__.py +8 -0
  183. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  184. diffusers/schedulers/scheduling_ddim.py +1 -1
  185. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  186. diffusers/schedulers/scheduling_ddpm.py +1 -1
  187. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  188. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  189. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  190. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  191. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  192. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  193. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  194. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  195. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  196. diffusers/schedulers/scheduling_ipndm.py +1 -1
  197. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  198. diffusers/schedulers/scheduling_utils.py +1 -3
  199. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  200. diffusers/training_utils.py +99 -14
  201. diffusers/utils/__init__.py +2 -2
  202. diffusers/utils/dummy_pt_objects.py +210 -0
  203. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  204. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  205. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  206. diffusers/utils/dynamic_modules_utils.py +1 -11
  207. diffusers/utils/export_utils.py +50 -6
  208. diffusers/utils/hub_utils.py +45 -42
  209. diffusers/utils/import_utils.py +37 -15
  210. diffusers/utils/loading_utils.py +80 -3
  211. diffusers/utils/testing_utils.py +11 -8
  212. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/METADATA +73 -83
  213. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/RECORD +217 -164
  214. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/WHEEL +1 -1
  215. diffusers/loaders/autoencoder.py +0 -146
  216. diffusers/loaders/controlnet.py +0 -136
  217. diffusers/loaders/lora.py +0 -1728
  218. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/LICENSE +0 -0
  219. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/entry_points.txt +0 -0
  220. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/top_level.txt +0 -0
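The headline change in this release is the reorganization of the LoRA loaders listed above: `diffusers/loaders/lora.py` is removed, and its functionality moves into the new `lora_base.py` and `lora_pipeline.py`, which define per-pipeline mixins such as `StableDiffusionLoraLoaderMixin` and `StableDiffusionXLLoraLoaderMixin` (shown in the diff below). As rough orientation, here is a minimal usage sketch of the pipeline-level entry points these mixins provide, mirroring the example embedded in the `fuse_lora` docstring further down; the model and LoRA repos are the ones used in that docstring, and a CUDA device is assumed:

```py
# Minimal sketch of the pipeline-level LoRA API provided by the new mixins.
# Repos mirror the fuse_lora docstring example below; a CUDA device is assumed.
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# load_lora_weights() fetches the checkpoint via lora_state_dict() and injects
# the adapter into the UNet and text encoder(s) through the PEFT backend.
pipeline.load_lora_weights(
    "nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel"
)

# Optionally fuse the adapter into the base weights for faster inference.
pipeline.fuse_lora(lora_scale=0.7)
```

As the diff below shows, `load_lora_weights` also accepts an already-loaded state dict (`Dict[str, torch.Tensor]`), which is copied rather than modified in place.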
diffusers/loaders/lora_pipeline.py (new file)
@@ -0,0 +1,2252 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ from typing import Callable, Dict, List, Optional, Union
16
+
17
+ import torch
18
+ from huggingface_hub.utils import validate_hf_hub_args
19
+
20
+ from ..utils import (
21
+ USE_PEFT_BACKEND,
22
+ convert_state_dict_to_diffusers,
23
+ convert_state_dict_to_peft,
24
+ convert_unet_state_dict_to_peft,
25
+ deprecate,
26
+ get_adapter_name,
27
+ get_peft_kwargs,
28
+ is_peft_version,
29
+ is_transformers_available,
30
+ logging,
31
+ scale_lora_layers,
32
+ )
33
+ from .lora_base import LoraBaseMixin
34
+ from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
35
+
36
+
37
+ if is_transformers_available():
38
+ from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+ TEXT_ENCODER_NAME = "text_encoder"
43
+ UNET_NAME = "unet"
44
+ TRANSFORMER_NAME = "transformer"
45
+
46
+ LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
47
+ LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
48
+
49
+
50
+ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
51
+ r"""
52
+ Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
53
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
54
+ """
55
+
56
+ _lora_loadable_modules = ["unet", "text_encoder"]
57
+ unet_name = UNET_NAME
58
+ text_encoder_name = TEXT_ENCODER_NAME
59
+
60
+ def load_lora_weights(
61
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
62
+ ):
63
+ """
64
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
65
+ `self.text_encoder`.
66
+
67
+ All kwargs are forwarded to `self.lora_state_dict`.
68
+
69
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
70
+ loaded.
71
+
72
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is
73
+ loaded into `self.unet`.
74
+
75
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state
76
+ dict is loaded into `self.text_encoder`.
77
+
78
+ Parameters:
79
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
80
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
81
+ kwargs (`dict`, *optional*):
82
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
83
+ adapter_name (`str`, *optional*):
84
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
85
+ `default_{i}` where i is the total number of adapters being loaded.
86
+ """
87
+ if not USE_PEFT_BACKEND:
88
+ raise ValueError("PEFT backend is required for this method.")
89
+
90
+ # if a dict is passed, copy it instead of modifying it inplace
91
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
92
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
93
+
94
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
95
+ state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
96
+
97
+ is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
98
+ if not is_correct_format:
99
+ raise ValueError("Invalid LoRA checkpoint.")
100
+
101
+ self.load_lora_into_unet(
102
+ state_dict,
103
+ network_alphas=network_alphas,
104
+ unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
105
+ adapter_name=adapter_name,
106
+ _pipeline=self,
107
+ )
108
+ self.load_lora_into_text_encoder(
109
+ state_dict,
110
+ network_alphas=network_alphas,
111
+ text_encoder=getattr(self, self.text_encoder_name)
112
+ if not hasattr(self, "text_encoder")
113
+ else self.text_encoder,
114
+ lora_scale=self.lora_scale,
115
+ adapter_name=adapter_name,
116
+ _pipeline=self,
117
+ )
118
+
119
+ @classmethod
120
+ @validate_hf_hub_args
121
+ def lora_state_dict(
122
+ cls,
123
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
124
+ **kwargs,
125
+ ):
126
+ r"""
127
+ Return state dict for lora weights and the network alphas.
128
+
129
+ <Tip warning={true}>
130
+
131
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
132
+
133
+ This function is experimental and might change in the future.
134
+
135
+ </Tip>
136
+
137
+ Parameters:
138
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
139
+ Can be either:
140
+
141
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
142
+ the Hub.
143
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
144
+ with [`ModelMixin.save_pretrained`].
145
+ - A [torch state
146
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
147
+
148
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
149
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
150
+ is not used.
151
+ force_download (`bool`, *optional*, defaults to `False`):
152
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
153
+ cached versions if they exist.
154
+
155
+ proxies (`Dict[str, str]`, *optional*):
156
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
157
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
158
+ local_files_only (`bool`, *optional*, defaults to `False`):
159
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
160
+ won't be downloaded from the Hub.
161
+ token (`str` or *bool*, *optional*):
162
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
163
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
164
+ revision (`str`, *optional*, defaults to `"main"`):
165
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
166
+ allowed by Git.
167
+ subfolder (`str`, *optional*, defaults to `""`):
168
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
169
+ weight_name (`str`, *optional*, defaults to None):
170
+ Name of the serialized state dict file.
171
+ """
172
+ # Load the main state dict first which has the LoRA layers for either of
173
+ # UNet and text encoder or both.
174
+ cache_dir = kwargs.pop("cache_dir", None)
175
+ force_download = kwargs.pop("force_download", False)
176
+ proxies = kwargs.pop("proxies", None)
177
+ local_files_only = kwargs.pop("local_files_only", None)
178
+ token = kwargs.pop("token", None)
179
+ revision = kwargs.pop("revision", None)
180
+ subfolder = kwargs.pop("subfolder", None)
181
+ weight_name = kwargs.pop("weight_name", None)
182
+ unet_config = kwargs.pop("unet_config", None)
183
+ use_safetensors = kwargs.pop("use_safetensors", None)
184
+
185
+ allow_pickle = False
186
+ if use_safetensors is None:
187
+ use_safetensors = True
188
+ allow_pickle = True
189
+
190
+ user_agent = {
191
+ "file_type": "attn_procs_weights",
192
+ "framework": "pytorch",
193
+ }
194
+
195
+ state_dict = cls._fetch_state_dict(
196
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
197
+ weight_name=weight_name,
198
+ use_safetensors=use_safetensors,
199
+ local_files_only=local_files_only,
200
+ cache_dir=cache_dir,
201
+ force_download=force_download,
202
+ proxies=proxies,
203
+ token=token,
204
+ revision=revision,
205
+ subfolder=subfolder,
206
+ user_agent=user_agent,
207
+ allow_pickle=allow_pickle,
208
+ )
209
+
210
+ network_alphas = None
211
+ # TODO: replace it with a method from `state_dict_utils`
212
+ if all(
213
+ (
214
+ k.startswith("lora_te_")
215
+ or k.startswith("lora_unet_")
216
+ or k.startswith("lora_te1_")
217
+ or k.startswith("lora_te2_")
218
+ )
219
+ for k in state_dict.keys()
220
+ ):
221
+ # Map SDXL blocks correctly.
222
+ if unet_config is not None:
223
+ # use unet config to remap block numbers
224
+ state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
225
+ state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
226
+
227
+ return state_dict, network_alphas
228
+
229
+ @classmethod
230
+ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
231
+ """
232
+ This will load the LoRA layers specified in `state_dict` into `unet`.
233
+
234
+ Parameters:
235
+ state_dict (`dict`):
236
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
237
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
238
+ encoder lora layers.
239
+ network_alphas (`Dict[str, float]`):
240
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
241
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
242
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
243
+ unet (`UNet2DConditionModel`):
244
+ The UNet model to load the LoRA layers into.
245
+ adapter_name (`str`, *optional*):
246
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
247
+ `default_{i}` where i is the total number of adapters being loaded.
248
+ """
249
+ if not USE_PEFT_BACKEND:
250
+ raise ValueError("PEFT backend is required for this method.")
251
+
252
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
253
+ # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
254
+ # their prefixes.
255
+ keys = list(state_dict.keys())
256
+ only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
257
+ if not only_text_encoder:
258
+ # Load the layers corresponding to UNet.
259
+ logger.info(f"Loading {cls.unet_name}.")
260
+ unet.load_attn_procs(
261
+ state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
262
+ )
263
+
264
+ @classmethod
265
+ def load_lora_into_text_encoder(
266
+ cls,
267
+ state_dict,
268
+ network_alphas,
269
+ text_encoder,
270
+ prefix=None,
271
+ lora_scale=1.0,
272
+ adapter_name=None,
273
+ _pipeline=None,
274
+ ):
275
+ """
276
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
277
+
278
+ Parameters:
279
+ state_dict (`dict`):
280
+ A standard state dict containing the lora layer parameters. The key should be prefixed with an
281
+ additional `text_encoder` to distinguish between unet lora layers.
282
+ network_alphas (`Dict[str, float]`):
283
+ See `LoRALinearLayer` for more details.
284
+ text_encoder (`CLIPTextModel`):
285
+ The text encoder model to load the LoRA layers into.
286
+ prefix (`str`):
287
+ Expected prefix of the `text_encoder` in the `state_dict`.
288
+ lora_scale (`float`):
289
+ How much to scale the output of the lora linear layer before it is added with the output of the regular
290
+ lora layer.
291
+ adapter_name (`str`, *optional*):
292
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
293
+ `default_{i}` where i is the total number of adapters being loaded.
294
+ """
295
+ if not USE_PEFT_BACKEND:
296
+ raise ValueError("PEFT backend is required for this method.")
297
+
298
+ from peft import LoraConfig
299
+
300
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
301
+ # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
302
+ # their prefixes.
303
+ keys = list(state_dict.keys())
304
+ prefix = cls.text_encoder_name if prefix is None else prefix
305
+
306
+ # Safe prefix to check with.
307
+ if any(cls.text_encoder_name in key for key in keys):
308
+ # Load the layers corresponding to text encoder and make necessary adjustments.
309
+ text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
310
+ text_encoder_lora_state_dict = {
311
+ k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
312
+ }
313
+
314
+ if len(text_encoder_lora_state_dict) > 0:
315
+ logger.info(f"Loading {prefix}.")
316
+ rank = {}
317
+ text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
318
+
319
+ # convert state dict
320
+ text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
321
+
322
+ for name, _ in text_encoder_attn_modules(text_encoder):
323
+ for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
324
+ rank_key = f"{name}.{module}.lora_B.weight"
325
+ if rank_key not in text_encoder_lora_state_dict:
326
+ continue
327
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
328
+
329
+ for name, _ in text_encoder_mlp_modules(text_encoder):
330
+ for module in ("fc1", "fc2"):
331
+ rank_key = f"{name}.{module}.lora_B.weight"
332
+ if rank_key not in text_encoder_lora_state_dict:
333
+ continue
334
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
335
+
336
+ if network_alphas is not None:
337
+ alpha_keys = [
338
+ k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
339
+ ]
340
+ network_alphas = {
341
+ k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
342
+ }
343
+
344
+ lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
345
+ if "use_dora" in lora_config_kwargs:
346
+ if lora_config_kwargs["use_dora"]:
347
+ if is_peft_version("<", "0.9.0"):
348
+ raise ValueError(
349
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
350
+ )
351
+ else:
352
+ if is_peft_version("<", "0.9.0"):
353
+ lora_config_kwargs.pop("use_dora")
354
+ lora_config = LoraConfig(**lora_config_kwargs)
355
+
356
+ # adapter_name
357
+ if adapter_name is None:
358
+ adapter_name = get_adapter_name(text_encoder)
359
+
360
+ is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
361
+
362
+ # inject LoRA layers and load the state dict
363
+ # in transformers we automatically check whether the adapter name is already in use or not
364
+ text_encoder.load_adapter(
365
+ adapter_name=adapter_name,
366
+ adapter_state_dict=text_encoder_lora_state_dict,
367
+ peft_config=lora_config,
368
+ )
369
+
370
+ # scale LoRA layers with `lora_scale`
371
+ scale_lora_layers(text_encoder, weight=lora_scale)
372
+
373
+ text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
374
+
375
+ # Offload back.
376
+ if is_model_cpu_offload:
377
+ _pipeline.enable_model_cpu_offload()
378
+ elif is_sequential_cpu_offload:
379
+ _pipeline.enable_sequential_cpu_offload()
380
+ # Unsafe code />
381
+
382
+ @classmethod
383
+ def save_lora_weights(
384
+ cls,
385
+ save_directory: Union[str, os.PathLike],
386
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
387
+ text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
388
+ is_main_process: bool = True,
389
+ weight_name: str = None,
390
+ save_function: Callable = None,
391
+ safe_serialization: bool = True,
392
+ ):
393
+ r"""
394
+ Save the LoRA parameters corresponding to the UNet and text encoder.
395
+
396
+ Arguments:
397
+ save_directory (`str` or `os.PathLike`):
398
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
399
+ unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
400
+ State dict of the LoRA layers corresponding to the `unet`.
401
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
402
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
403
+ encoder LoRA state dict because it comes from 🤗 Transformers.
404
+ is_main_process (`bool`, *optional*, defaults to `True`):
405
+ Whether the process calling this is the main process or not. Useful during distributed training and you
406
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
407
+ process to avoid race conditions.
408
+ save_function (`Callable`):
409
+ The function to use to save the state dictionary. Useful during distributed training when you need to
410
+ replace `torch.save` with another method. Can be configured with the environment variable
411
+ `DIFFUSERS_SAVE_MODE`.
412
+ safe_serialization (`bool`, *optional*, defaults to `True`):
413
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
414
+ """
415
+ state_dict = {}
416
+
417
+ if not (unet_lora_layers or text_encoder_lora_layers):
418
+ raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")
419
+
420
+ if unet_lora_layers:
421
+ state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
422
+
423
+ if text_encoder_lora_layers:
424
+ state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
425
+
426
+ # Save the model
427
+ cls.write_lora_layers(
428
+ state_dict=state_dict,
429
+ save_directory=save_directory,
430
+ is_main_process=is_main_process,
431
+ weight_name=weight_name,
432
+ save_function=save_function,
433
+ safe_serialization=safe_serialization,
434
+ )
435
+
436
+ def fuse_lora(
437
+ self,
438
+ components: List[str] = ["unet", "text_encoder"],
439
+ lora_scale: float = 1.0,
440
+ safe_fusing: bool = False,
441
+ adapter_names: Optional[List[str]] = None,
442
+ **kwargs,
443
+ ):
444
+ r"""
445
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
446
+
447
+ <Tip warning={true}>
448
+
449
+ This is an experimental API.
450
+
451
+ </Tip>
452
+
453
+ Args:
454
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
455
+ lora_scale (`float`, defaults to 1.0):
456
+ Controls how much to influence the outputs with the LoRA parameters.
457
+ safe_fusing (`bool`, defaults to `False`):
458
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
459
+ adapter_names (`List[str]`, *optional*):
460
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
461
+
462
+ Example:
463
+
464
+ ```py
465
+ from diffusers import DiffusionPipeline
466
+ import torch
467
+
468
+ pipeline = DiffusionPipeline.from_pretrained(
469
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
470
+ ).to("cuda")
471
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
472
+ pipeline.fuse_lora(lora_scale=0.7)
473
+ ```
474
+ """
475
+ super().fuse_lora(
476
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
477
+ )
478
+
479
+ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
480
+ r"""
481
+ Reverses the effect of
482
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
483
+
484
+ <Tip warning={true}>
485
+
486
+ This is an experimental API.
487
+
488
+ </Tip>
489
+
490
+ Args:
491
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
492
+ unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
493
+ unfuse_text_encoder (`bool`, defaults to `True`):
494
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
495
+ LoRA parameters then it won't have any effect.
496
+ """
497
+ super().unfuse_lora(components=components)
498
+
499
+
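
The `lora_state_dict` classmethod documented above can also be called on its own to inspect a checkpoint before loading it. A hedged sketch, assuming a hypothetical SD 1.5 base model and LoRA repository (neither appears in this diff):

```py
# Sketch only: peek at what lora_state_dict() returns before loading.
# Both repo ids below are hypothetical placeholders, not taken from this diff.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

state_dict, network_alphas = pipe.lora_state_dict("some-user/some-sd15-lora")

unet_keys = [k for k in state_dict if k.startswith("unet.")]
te_keys = [k for k in state_dict if k.startswith("text_encoder.")]
print(f"{len(unet_keys)} UNet LoRA tensors, {len(te_keys)} text-encoder LoRA tensors")
print("network_alphas:", "none" if network_alphas is None else f"{len(network_alphas)} entries")
```

For checkpoints in the kohya-ss/A1111 key layout, the conversion branch shown above returns diffusers-style keys together with the `network_alphas` mapping used for scaling.
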
500
+ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
501
+ r"""
502
+ Load LoRA layers into Stable Diffusion XL [`UNet2DConditionModel`],
503
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), and
504
+ [`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection).
505
+ """
506
+
507
+ _lora_loadable_modules = ["unet", "text_encoder", "text_encoder_2"]
508
+ unet_name = UNET_NAME
509
+ text_encoder_name = TEXT_ENCODER_NAME
510
+
511
+ def load_lora_weights(
512
+ self,
513
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
514
+ adapter_name: Optional[str] = None,
515
+ **kwargs,
516
+ ):
517
+ """
518
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
519
+ `self.text_encoder`.
520
+
521
+ All kwargs are forwarded to `self.lora_state_dict`.
522
+
523
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
524
+ loaded.
525
+
526
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is
527
+ loaded into `self.unet`.
528
+
529
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state
530
+ dict is loaded into `self.text_encoder`.
531
+
532
+ Parameters:
533
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
534
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
535
+ adapter_name (`str`, *optional*):
536
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
537
+ `default_{i}` where i is the total number of adapters being loaded.
538
+ kwargs (`dict`, *optional*):
539
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
540
+ """
541
+ if not USE_PEFT_BACKEND:
542
+ raise ValueError("PEFT backend is required for this method.")
543
+
544
+ # We could have accessed the unet config from `lora_state_dict()` too. We pass
545
+ # it here explicitly to be able to tell that it's coming from an SDXL
546
+ # pipeline.
547
+
548
+ # if a dict is passed, copy it instead of modifying it inplace
549
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
550
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
551
+
552
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
553
+ state_dict, network_alphas = self.lora_state_dict(
554
+ pretrained_model_name_or_path_or_dict,
555
+ unet_config=self.unet.config,
556
+ **kwargs,
557
+ )
558
+ is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
559
+ if not is_correct_format:
560
+ raise ValueError("Invalid LoRA checkpoint.")
561
+
562
+ self.load_lora_into_unet(
563
+ state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
564
+ )
565
+ text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
566
+ if len(text_encoder_state_dict) > 0:
567
+ self.load_lora_into_text_encoder(
568
+ text_encoder_state_dict,
569
+ network_alphas=network_alphas,
570
+ text_encoder=self.text_encoder,
571
+ prefix="text_encoder",
572
+ lora_scale=self.lora_scale,
573
+ adapter_name=adapter_name,
574
+ _pipeline=self,
575
+ )
576
+
577
+ text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
578
+ if len(text_encoder_2_state_dict) > 0:
579
+ self.load_lora_into_text_encoder(
580
+ text_encoder_2_state_dict,
581
+ network_alphas=network_alphas,
582
+ text_encoder=self.text_encoder_2,
583
+ prefix="text_encoder_2",
584
+ lora_scale=self.lora_scale,
585
+ adapter_name=adapter_name,
586
+ _pipeline=self,
587
+ )
588
+
589
+ @classmethod
590
+ @validate_hf_hub_args
591
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict
592
+ def lora_state_dict(
593
+ cls,
594
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
595
+ **kwargs,
596
+ ):
597
+ r"""
598
+ Return state dict for lora weights and the network alphas.
599
+
600
+ <Tip warning={true}>
601
+
602
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
603
+
604
+ This function is experimental and might change in the future.
605
+
606
+ </Tip>
607
+
608
+ Parameters:
609
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
610
+ Can be either:
611
+
612
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
613
+ the Hub.
614
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
615
+ with [`ModelMixin.save_pretrained`].
616
+ - A [torch state
617
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
618
+
619
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
620
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
621
+ is not used.
622
+ force_download (`bool`, *optional*, defaults to `False`):
623
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
624
+ cached versions if they exist.
625
+
626
+ proxies (`Dict[str, str]`, *optional*):
627
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
628
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
629
+ local_files_only (`bool`, *optional*, defaults to `False`):
630
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
631
+ won't be downloaded from the Hub.
632
+ token (`str` or *bool*, *optional*):
633
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
634
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
635
+ revision (`str`, *optional*, defaults to `"main"`):
636
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
637
+ allowed by Git.
638
+ subfolder (`str`, *optional*, defaults to `""`):
639
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
640
+ weight_name (`str`, *optional*, defaults to None):
641
+ Name of the serialized state dict file.
642
+ """
643
+ # Load the main state dict first which has the LoRA layers for either of
644
+ # UNet and text encoder or both.
645
+ cache_dir = kwargs.pop("cache_dir", None)
646
+ force_download = kwargs.pop("force_download", False)
647
+ proxies = kwargs.pop("proxies", None)
648
+ local_files_only = kwargs.pop("local_files_only", None)
649
+ token = kwargs.pop("token", None)
650
+ revision = kwargs.pop("revision", None)
651
+ subfolder = kwargs.pop("subfolder", None)
652
+ weight_name = kwargs.pop("weight_name", None)
653
+ unet_config = kwargs.pop("unet_config", None)
654
+ use_safetensors = kwargs.pop("use_safetensors", None)
655
+
656
+ allow_pickle = False
657
+ if use_safetensors is None:
658
+ use_safetensors = True
659
+ allow_pickle = True
660
+
661
+ user_agent = {
662
+ "file_type": "attn_procs_weights",
663
+ "framework": "pytorch",
664
+ }
665
+
666
+ state_dict = cls._fetch_state_dict(
667
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
668
+ weight_name=weight_name,
669
+ use_safetensors=use_safetensors,
670
+ local_files_only=local_files_only,
671
+ cache_dir=cache_dir,
672
+ force_download=force_download,
673
+ proxies=proxies,
674
+ token=token,
675
+ revision=revision,
676
+ subfolder=subfolder,
677
+ user_agent=user_agent,
678
+ allow_pickle=allow_pickle,
679
+ )
680
+
681
+ network_alphas = None
682
+ # TODO: replace it with a method from `state_dict_utils`
683
+ if all(
684
+ (
685
+ k.startswith("lora_te_")
686
+ or k.startswith("lora_unet_")
687
+ or k.startswith("lora_te1_")
688
+ or k.startswith("lora_te2_")
689
+ )
690
+ for k in state_dict.keys()
691
+ ):
692
+ # Map SDXL blocks correctly.
693
+ if unet_config is not None:
694
+ # use unet config to remap block numbers
695
+ state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
696
+ state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
697
+
698
+ return state_dict, network_alphas
699
+
700
+ @classmethod
701
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
702
+ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
703
+ """
704
+ This will load the LoRA layers specified in `state_dict` into `unet`.
705
+
706
+ Parameters:
707
+ state_dict (`dict`):
708
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
709
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
710
+ encoder lora layers.
711
+ network_alphas (`Dict[str, float]`):
712
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
713
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
714
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
715
+ unet (`UNet2DConditionModel`):
716
+ The UNet model to load the LoRA layers into.
717
+ adapter_name (`str`, *optional*):
718
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
719
+ `default_{i}` where i is the total number of adapters being loaded.
720
+ """
721
+ if not USE_PEFT_BACKEND:
722
+ raise ValueError("PEFT backend is required for this method.")
723
+
724
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
725
+ # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
726
+ # their prefixes.
727
+ keys = list(state_dict.keys())
728
+ only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
729
+ if not only_text_encoder:
730
+ # Load the layers corresponding to UNet.
731
+ logger.info(f"Loading {cls.unet_name}.")
732
+ unet.load_attn_procs(
733
+ state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
734
+ )
735
+
736
+ @classmethod
737
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
738
+ def load_lora_into_text_encoder(
739
+ cls,
740
+ state_dict,
741
+ network_alphas,
742
+ text_encoder,
743
+ prefix=None,
744
+ lora_scale=1.0,
745
+ adapter_name=None,
746
+ _pipeline=None,
747
+ ):
748
+ """
749
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
750
+
751
+ Parameters:
752
+ state_dict (`dict`):
753
+ A standard state dict containing the lora layer parameters. The key should be prefixed with an
754
+ additional `text_encoder` to distinguish between unet lora layers.
755
+ network_alphas (`Dict[str, float]`):
756
+ See `LoRALinearLayer` for more details.
757
+ text_encoder (`CLIPTextModel`):
758
+ The text encoder model to load the LoRA layers into.
759
+ prefix (`str`):
760
+ Expected prefix of the `text_encoder` in the `state_dict`.
761
+ lora_scale (`float`):
762
+ How much to scale the output of the lora linear layer before it is added with the output of the regular
763
+ lora layer.
764
+ adapter_name (`str`, *optional*):
765
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
766
+ `default_{i}` where i is the total number of adapters being loaded.
767
+ """
768
+ if not USE_PEFT_BACKEND:
769
+ raise ValueError("PEFT backend is required for this method.")
770
+
771
+ from peft import LoraConfig
772
+
773
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
774
+ # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
775
+ # their prefixes.
776
+ keys = list(state_dict.keys())
777
+ prefix = cls.text_encoder_name if prefix is None else prefix
778
+
779
+ # Safe prefix to check with.
780
+ if any(cls.text_encoder_name in key for key in keys):
781
+ # Load the layers corresponding to text encoder and make necessary adjustments.
782
+ text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
783
+ text_encoder_lora_state_dict = {
784
+ k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
785
+ }
786
+
787
+ if len(text_encoder_lora_state_dict) > 0:
788
+ logger.info(f"Loading {prefix}.")
789
+ rank = {}
790
+ text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
791
+
792
+ # convert state dict
793
+ text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
794
+
795
+ for name, _ in text_encoder_attn_modules(text_encoder):
796
+ for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
797
+ rank_key = f"{name}.{module}.lora_B.weight"
798
+ if rank_key not in text_encoder_lora_state_dict:
799
+ continue
800
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
801
+
802
+ for name, _ in text_encoder_mlp_modules(text_encoder):
803
+ for module in ("fc1", "fc2"):
804
+ rank_key = f"{name}.{module}.lora_B.weight"
805
+ if rank_key not in text_encoder_lora_state_dict:
806
+ continue
807
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
808
+
809
+ if network_alphas is not None:
810
+ alpha_keys = [
811
+ k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
812
+ ]
813
+ network_alphas = {
814
+ k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
815
+ }
816
+
817
+ lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
818
+ if "use_dora" in lora_config_kwargs:
819
+ if lora_config_kwargs["use_dora"]:
820
+ if is_peft_version("<", "0.9.0"):
821
+ raise ValueError(
822
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
823
+ )
824
+ else:
825
+ if is_peft_version("<", "0.9.0"):
826
+ lora_config_kwargs.pop("use_dora")
827
+ lora_config = LoraConfig(**lora_config_kwargs)
828
+
829
+ # adapter_name
830
+ if adapter_name is None:
831
+ adapter_name = get_adapter_name(text_encoder)
832
+
833
+ is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
834
+
835
+ # inject LoRA layers and load the state dict
836
+ # in transformers we automatically check whether the adapter name is already in use or not
837
+ text_encoder.load_adapter(
838
+ adapter_name=adapter_name,
839
+ adapter_state_dict=text_encoder_lora_state_dict,
840
+ peft_config=lora_config,
841
+ )
842
+
843
+ # scale LoRA layers with `lora_scale`
844
+ scale_lora_layers(text_encoder, weight=lora_scale)
845
+
846
+ text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
847
+
848
+ # Offload back.
849
+ if is_model_cpu_offload:
850
+ _pipeline.enable_model_cpu_offload()
851
+ elif is_sequential_cpu_offload:
852
+ _pipeline.enable_sequential_cpu_offload()
853
+ # Unsafe code />
854
+
855
+ @classmethod
856
+ def save_lora_weights(
857
+ cls,
858
+ save_directory: Union[str, os.PathLike],
859
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
860
+ text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
861
+ text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
862
+ is_main_process: bool = True,
863
+ weight_name: str = None,
864
+ save_function: Callable = None,
865
+ safe_serialization: bool = True,
866
+ ):
867
+ r"""
868
+ Save the LoRA parameters corresponding to the UNet and text encoder.
869
+
870
+ Arguments:
871
+ save_directory (`str` or `os.PathLike`):
872
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
873
+ unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
874
+ State dict of the LoRA layers corresponding to the `unet`.
875
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
876
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
877
+ encoder LoRA state dict because it comes from 🤗 Transformers.
878
+ text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
879
+ State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
880
+ encoder LoRA state dict because it comes from 🤗 Transformers.
881
+ is_main_process (`bool`, *optional*, defaults to `True`):
882
+ Whether the process calling this is the main process or not. Useful during distributed training and you
883
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
884
+ process to avoid race conditions.
885
+ save_function (`Callable`):
886
+ The function to use to save the state dictionary. Useful during distributed training when you need to
887
+ replace `torch.save` with another method. Can be configured with the environment variable
888
+ `DIFFUSERS_SAVE_MODE`.
889
+ safe_serialization (`bool`, *optional*, defaults to `True`):
890
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
891
+ """
892
+ state_dict = {}
893
+
894
+ if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
895
+ raise ValueError(
896
+ "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
897
+ )
898
+
899
+ if unet_lora_layers:
900
+ state_dict.update(cls.pack_weights(unet_lora_layers, "unet"))
901
+
902
+ if text_encoder_lora_layers:
903
+ state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
904
+
905
+ if text_encoder_2_lora_layers:
906
+ state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
907
+
908
+ cls.write_lora_layers(
909
+ state_dict=state_dict,
910
+ save_directory=save_directory,
911
+ is_main_process=is_main_process,
912
+ weight_name=weight_name,
913
+ save_function=save_function,
914
+ safe_serialization=safe_serialization,
915
+ )
916
+
917
+ def fuse_lora(
918
+ self,
919
+ components: List[str] = ["unet", "text_encoder", "text_encoder_2"],
920
+ lora_scale: float = 1.0,
921
+ safe_fusing: bool = False,
922
+ adapter_names: Optional[List[str]] = None,
923
+ **kwargs,
924
+ ):
925
+ r"""
926
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
927
+
928
+ <Tip warning={true}>
929
+
930
+ This is an experimental API.
931
+
932
+ </Tip>
933
+
934
+ Args:
935
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
936
+ lora_scale (`float`, defaults to 1.0):
937
+ Controls how much to influence the outputs with the LoRA parameters.
938
+ safe_fusing (`bool`, defaults to `False`):
939
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
940
+ adapter_names (`List[str]`, *optional*):
941
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
942
+
943
+ Example:
944
+
945
+ ```py
946
+ from diffusers import DiffusionPipeline
947
+ import torch
948
+
949
+ pipeline = DiffusionPipeline.from_pretrained(
950
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
951
+ ).to("cuda")
952
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
953
+ pipeline.fuse_lora(lora_scale=0.7)
954
+ ```
955
+ """
956
+ super().fuse_lora(
957
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
958
+ )
959
+
960
+ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
961
+ r"""
962
+ Reverses the effect of
963
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
964
+
965
+ <Tip warning={true}>
966
+
967
+ This is an experimental API.
968
+
969
+ </Tip>
970
+
971
+ Args:
972
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
973
+ unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
974
+ unfuse_text_encoder (`bool`, defaults to `True`):
975
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
976
+ LoRA parameters then it won't have any effect.
977
+ """
978
+ super().unfuse_lora(components=components)
979
+
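Putting the `fuse_lora`/`unfuse_lora` pair above into context: fusing folds the adapter into the base weights (so inference runs without LoRA hooks) and unfusing restores the originals. A hedged round-trip sketch reusing the checkpoints from the docstring example:

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
    "nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel"
)

# Fold the adapter into the UNet and both text encoders, run inference, then undo the fusion.
pipeline.fuse_lora(components=["unet", "text_encoder", "text_encoder_2"], lora_scale=0.7)
image = pipeline("pixel art of a corgi astronaut", num_inference_steps=30).images[0]
pipeline.unfuse_lora(components=["unet", "text_encoder", "text_encoder_2"])
```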
980
+
981
+ class SD3LoraLoaderMixin(LoraBaseMixin):
982
+ r"""
983
+ Load LoRA layers into [`SD3Transformer2DModel`],
984
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), and
985
+ [`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection).
986
+
987
+ Specific to [`StableDiffusion3Pipeline`].
988
+ """
989
+
990
+ _lora_loadable_modules = ["transformer", "text_encoder", "text_encoder_2"]
991
+ transformer_name = TRANSFORMER_NAME
992
+ text_encoder_name = TEXT_ENCODER_NAME
993
+
994
+ @classmethod
995
+ @validate_hf_hub_args
996
+ def lora_state_dict(
997
+ cls,
998
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
999
+ **kwargs,
1000
+ ):
1001
+ r"""
1002
+ Return the state dict for the LoRA weights.
1003
+
1004
+ <Tip warning={true}>
1005
+
1006
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
1007
+
1008
+ This function is experimental and might change in the future.
1009
+
1010
+ </Tip>
1011
+
1012
+ Parameters:
1013
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
1014
+ Can be either:
1015
+
1016
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
1017
+ the Hub.
1018
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
1019
+ with [`ModelMixin.save_pretrained`].
1020
+ - A [torch state
1021
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
1022
+
1023
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
1024
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
1025
+ is not used.
1026
+ force_download (`bool`, *optional*, defaults to `False`):
1027
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1028
+ cached versions if they exist.
1029
+
1030
+ proxies (`Dict[str, str]`, *optional*):
1031
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
1032
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1033
+ local_files_only (`bool`, *optional*, defaults to `False`):
1034
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
1035
+ won't be downloaded from the Hub.
1036
+ token (`str` or *bool*, *optional*):
1037
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
1038
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
1039
+ revision (`str`, *optional*, defaults to `"main"`):
1040
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
1041
+ allowed by Git.
1042
+ subfolder (`str`, *optional*, defaults to `""`):
1043
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
1044
+
1045
+ """
1046
+ # Load the main state dict first which has the LoRA layers for either of
1047
+ # transformer and text encoder or both.
1048
+ cache_dir = kwargs.pop("cache_dir", None)
1049
+ force_download = kwargs.pop("force_download", False)
1050
+ proxies = kwargs.pop("proxies", None)
1051
+ local_files_only = kwargs.pop("local_files_only", None)
1052
+ token = kwargs.pop("token", None)
1053
+ revision = kwargs.pop("revision", None)
1054
+ subfolder = kwargs.pop("subfolder", None)
1055
+ weight_name = kwargs.pop("weight_name", None)
1056
+ use_safetensors = kwargs.pop("use_safetensors", None)
1057
+
1058
+ allow_pickle = False
1059
+ if use_safetensors is None:
1060
+ use_safetensors = True
1061
+ allow_pickle = True
1062
+
1063
+ user_agent = {
1064
+ "file_type": "attn_procs_weights",
1065
+ "framework": "pytorch",
1066
+ }
1067
+
1068
+ state_dict = cls._fetch_state_dict(
1069
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
1070
+ weight_name=weight_name,
1071
+ use_safetensors=use_safetensors,
1072
+ local_files_only=local_files_only,
1073
+ cache_dir=cache_dir,
1074
+ force_download=force_download,
1075
+ proxies=proxies,
1076
+ token=token,
1077
+ revision=revision,
1078
+ subfolder=subfolder,
1079
+ user_agent=user_agent,
1080
+ allow_pickle=allow_pickle,
1081
+ )
1082
+
1083
+ return state_dict
1084
+
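`lora_state_dict` is also useful on its own for fetching and inspecting a checkpoint without attaching it to a pipeline. A sketch via `StableDiffusion3Pipeline`, which inherits this mixin; the repository id is a placeholder:

```py
from diffusers import StableDiffusion3Pipeline

# "some-user/sd3-lora" is hypothetical; substitute any SD3 LoRA saved in the diffusers format.
state_dict = StableDiffusion3Pipeline.lora_state_dict("some-user/sd3-lora")

# Top-level prefixes reveal which components the checkpoint targets.
print(sorted({key.split(".")[0] for key in state_dict}))  # e.g. ['text_encoder', 'transformer']
```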
1085
+ def load_lora_weights(
1086
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
1087
+ ):
1088
+ """
1089
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`,
1090
+ `self.text_encoder` and `self.text_encoder_2`.
1091
+
1092
+ All kwargs are forwarded to `self.lora_state_dict`.
1093
+
1094
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
1095
+ loaded.
1096
+
1097
+ See [`~loaders.SD3LoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
1098
+ dict is loaded into `self.transformer`.
1099
+
1100
+ Parameters:
1101
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
1102
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
1103
+ kwargs (`dict`, *optional*):
1104
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
1105
+ adapter_name (`str`, *optional*):
1106
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
1107
+ `default_{i}` where i is the total number of adapters being loaded.
1108
+ """
1109
+ if not USE_PEFT_BACKEND:
1110
+ raise ValueError("PEFT backend is required for this method.")
1111
+
1112
+ # if a dict is passed, copy it instead of modifying it inplace
1113
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
1114
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
1115
+
1116
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
1117
+ state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
1118
+
1119
+ is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
1120
+ if not is_correct_format:
1121
+ raise ValueError("Invalid LoRA checkpoint.")
1122
+
1123
+ self.load_lora_into_transformer(
1124
+ state_dict,
1125
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
1126
+ adapter_name=adapter_name,
1127
+ _pipeline=self,
1128
+ )
1129
+
1130
+ text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
1131
+ if len(text_encoder_state_dict) > 0:
1132
+ self.load_lora_into_text_encoder(
1133
+ text_encoder_state_dict,
1134
+ network_alphas=None,
1135
+ text_encoder=self.text_encoder,
1136
+ prefix="text_encoder",
1137
+ lora_scale=self.lora_scale,
1138
+ adapter_name=adapter_name,
1139
+ _pipeline=self,
1140
+ )
1141
+
1142
+ text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
1143
+ if len(text_encoder_2_state_dict) > 0:
1144
+ self.load_lora_into_text_encoder(
1145
+ text_encoder_2_state_dict,
1146
+ network_alphas=None,
1147
+ text_encoder=self.text_encoder_2,
1148
+ prefix="text_encoder_2",
1149
+ lora_scale=self.lora_scale,
1150
+ adapter_name=adapter_name,
1151
+ _pipeline=self,
1152
+ )
1153
+
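A usage sketch for the loading path above; the base checkpoint is the gated SD3 repository and the LoRA location is a local placeholder, so adapt both to what you have access to:

```py
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# Loads the transformer LoRA and, if the checkpoint contains them, the text encoder LoRAs too.
pipe.load_lora_weights("./sd3-lora", weight_name="pytorch_lora_weights.safetensors", adapter_name="style")
image = pipe("a watercolor painting of a fox", num_inference_steps=28).images[0]
```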
1154
+ @classmethod
1155
+ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
1156
+ """
1157
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
1158
+
1159
+ Parameters:
1160
+ state_dict (`dict`):
1161
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
1162
+ into the transformer or prefixed with an additional `transformer`, which can be used to distinguish them
1163
+ from the text encoder LoRA layers.
1164
+ transformer (`SD3Transformer2DModel`):
1165
+ The Transformer model to load the LoRA layers into.
1166
+ adapter_name (`str`, *optional*):
1167
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
1168
+ `default_{i}` where i is the total number of adapters being loaded.
1169
+ """
1170
+ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
1171
+
1172
+ keys = list(state_dict.keys())
1173
+
1174
+ transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
1175
+ state_dict = {
1176
+ k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
1177
+ }
1178
+
1179
+ if len(state_dict.keys()) > 0:
1180
+ # check with first key if is not in peft format
1181
+ first_key = next(iter(state_dict.keys()))
1182
+ if "lora_A" not in first_key:
1183
+ state_dict = convert_unet_state_dict_to_peft(state_dict)
1184
+
1185
+ if adapter_name in getattr(transformer, "peft_config", {}):
1186
+ raise ValueError(
1187
+ f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
1188
+ )
1189
+
1190
+ rank = {}
1191
+ for key, val in state_dict.items():
1192
+ if "lora_B" in key:
1193
+ rank[key] = val.shape[1]
1194
+
1195
+ lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
1196
+ if "use_dora" in lora_config_kwargs:
1197
+ if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
1198
+ raise ValueError(
1199
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
1200
+ )
1201
+ else:
1202
+ lora_config_kwargs.pop("use_dora")
1203
+ lora_config = LoraConfig(**lora_config_kwargs)
1204
+
1205
+ # adapter_name
1206
+ if adapter_name is None:
1207
+ adapter_name = get_adapter_name(transformer)
1208
+
1209
+ # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
1210
+ # otherwise loading LoRA weights will lead to an error
1211
+ is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
1212
+
1213
+ inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
1214
+ incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
1215
+
1216
+ if incompatible_keys is not None:
1217
+ # check only for unexpected keys
1218
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
1219
+ if unexpected_keys:
1220
+ logger.warning(
1221
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
1222
+ f" {unexpected_keys}. "
1223
+ )
1224
+
1225
+ # Offload back.
1226
+ if is_model_cpu_offload:
1227
+ _pipeline.enable_model_cpu_offload()
1228
+ elif is_sequential_cpu_offload:
1229
+ _pipeline.enable_sequential_cpu_offload()
1230
+ # Unsafe code />
1231
+
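The rank bookkeeping in `load_lora_into_transformer` relies on the PEFT shape convention: `lora_A.weight` is `(r, in_features)` and `lora_B.weight` is `(out_features, r)`, so the adapter rank is the second dimension of every `lora_B` weight. A self-contained illustration with stand-in tensors:

```py
import torch

# Stand-in for a transformer LoRA state dict after the `transformer.` prefix has been stripped.
state_dict = {
    "transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(8, 1536),
    "transformer_blocks.0.attn.to_q.lora_B.weight": torch.zeros(1536, 8),
}

rank = {key: val.shape[1] for key, val in state_dict.items() if "lora_B" in key}
print(rank)  # {'transformer_blocks.0.attn.to_q.lora_B.weight': 8}
```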
1232
+ @classmethod
1233
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
1234
+ def load_lora_into_text_encoder(
1235
+ cls,
1236
+ state_dict,
1237
+ network_alphas,
1238
+ text_encoder,
1239
+ prefix=None,
1240
+ lora_scale=1.0,
1241
+ adapter_name=None,
1242
+ _pipeline=None,
1243
+ ):
1244
+ """
1245
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
1246
+
1247
+ Parameters:
1248
+ state_dict (`dict`):
1249
+ A standard state dict containing the lora layer parameters. The key should be prefixed with an
1250
+ additional `text_encoder` to distinguish it from the UNet LoRA layers.
1251
+ network_alphas (`Dict[str, float]`):
1252
+ See `LoRALinearLayer` for more details.
1253
+ text_encoder (`CLIPTextModel`):
1254
+ The text encoder model to load the LoRA layers into.
1255
+ prefix (`str`):
1256
+ Expected prefix of the `text_encoder` in the `state_dict`.
1257
+ lora_scale (`float`):
1258
+ How much to scale the output of the LoRA linear layer before it is added to the output of the
1258
+ original layer.
1260
+ adapter_name (`str`, *optional*):
1261
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
1262
+ `default_{i}` where i is the total number of adapters being loaded.
1263
+ """
1264
+ if not USE_PEFT_BACKEND:
1265
+ raise ValueError("PEFT backend is required for this method.")
1266
+
1267
+ from peft import LoraConfig
1268
+
1269
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
1270
+ # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
1271
+ # their prefixes.
1272
+ keys = list(state_dict.keys())
1273
+ prefix = cls.text_encoder_name if prefix is None else prefix
1274
+
1275
+ # Safe prefix to check with.
1276
+ if any(cls.text_encoder_name in key for key in keys):
1277
+ # Load the layers corresponding to text encoder and make necessary adjustments.
1278
+ text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
1279
+ text_encoder_lora_state_dict = {
1280
+ k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
1281
+ }
1282
+
1283
+ if len(text_encoder_lora_state_dict) > 0:
1284
+ logger.info(f"Loading {prefix}.")
1285
+ rank = {}
1286
+ text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
1287
+
1288
+ # convert state dict
1289
+ text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
1290
+
1291
+ for name, _ in text_encoder_attn_modules(text_encoder):
1292
+ for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
1293
+ rank_key = f"{name}.{module}.lora_B.weight"
1294
+ if rank_key not in text_encoder_lora_state_dict:
1295
+ continue
1296
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
1297
+
1298
+ for name, _ in text_encoder_mlp_modules(text_encoder):
1299
+ for module in ("fc1", "fc2"):
1300
+ rank_key = f"{name}.{module}.lora_B.weight"
1301
+ if rank_key not in text_encoder_lora_state_dict:
1302
+ continue
1303
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
1304
+
1305
+ if network_alphas is not None:
1306
+ alpha_keys = [
1307
+ k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
1308
+ ]
1309
+ network_alphas = {
1310
+ k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
1311
+ }
1312
+
1313
+ lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
1314
+ if "use_dora" in lora_config_kwargs:
1315
+ if lora_config_kwargs["use_dora"]:
1316
+ if is_peft_version("<", "0.9.0"):
1317
+ raise ValueError(
1318
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
1319
+ )
1320
+ else:
1321
+ if is_peft_version("<", "0.9.0"):
1322
+ lora_config_kwargs.pop("use_dora")
1323
+ lora_config = LoraConfig(**lora_config_kwargs)
1324
+
1325
+ # adapter_name
1326
+ if adapter_name is None:
1327
+ adapter_name = get_adapter_name(text_encoder)
1328
+
1329
+ is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
1330
+
1331
+ # inject LoRA layers and load the state dict
1332
+ # in transformers we automatically check whether the adapter name is already in use or not
1333
+ text_encoder.load_adapter(
1334
+ adapter_name=adapter_name,
1335
+ adapter_state_dict=text_encoder_lora_state_dict,
1336
+ peft_config=lora_config,
1337
+ )
1338
+
1339
+ # scale LoRA layers with `lora_scale`
1340
+ scale_lora_layers(text_encoder, weight=lora_scale)
1341
+
1342
+ text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
1343
+
1344
+ # Offload back.
1345
+ if is_model_cpu_offload:
1346
+ _pipeline.enable_model_cpu_offload()
1347
+ elif is_sequential_cpu_offload:
1348
+ _pipeline.enable_sequential_cpu_offload()
1349
+ # Unsafe code />
1350
+
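The text-encoder branch above first narrows the combined state dict to keys under the `text_encoder` prefix and strips that prefix before handing the result to PEFT. The filtering is plain dictionary work, sketched here with dummy values:

```py
state_dict = {
    "transformer.transformer_blocks.0.attn.to_q.lora_A.weight": None,
    "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": None,
}

prefix = "text_encoder"
text_encoder_keys = [k for k in state_dict if k.startswith(prefix) and k.split(".")[0] == prefix]
text_encoder_lora_state_dict = {
    k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
print(list(text_encoder_lora_state_dict))
# ['text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight']
```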
1351
+ @classmethod
1352
+ def save_lora_weights(
1353
+ cls,
1354
+ save_directory: Union[str, os.PathLike],
1355
+ transformer_lora_layers: Dict[str, torch.nn.Module] = None,
1356
+ text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1357
+ text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1358
+ is_main_process: bool = True,
1359
+ weight_name: str = None,
1360
+ save_function: Callable = None,
1361
+ safe_serialization: bool = True,
1362
+ ):
1363
+ r"""
1364
+ Save the LoRA parameters corresponding to the transformer and the text encoders.
1365
+
1366
+ Arguments:
1367
+ save_directory (`str` or `os.PathLike`):
1368
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
1369
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
1370
+ State dict of the LoRA layers corresponding to the `transformer`.
1371
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
1372
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
1373
+ encoder LoRA state dict because it comes from 🤗 Transformers.
1374
+ text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
1375
+ State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
1376
+ encoder LoRA state dict because it comes from 🤗 Transformers.
1377
+ is_main_process (`bool`, *optional*, defaults to `True`):
1378
+ Whether the process calling this is the main process or not. Useful during distributed training when you
1379
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
1380
+ process to avoid race conditions.
1381
+ save_function (`Callable`):
1382
+ The function to use to save the state dictionary. Useful during distributed training when you need to
1383
+ replace `torch.save` with another method. Can be configured with the environment variable
1384
+ `DIFFUSERS_SAVE_MODE`.
1385
+ safe_serialization (`bool`, *optional*, defaults to `True`):
1386
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
1387
+ """
1388
+ state_dict = {}
1389
+
1390
+ if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
1391
+ raise ValueError(
1392
+ "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
1393
+ )
1394
+
1395
+ if transformer_lora_layers:
1396
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
1397
+
1398
+ if text_encoder_lora_layers:
1399
+ state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
1400
+
1401
+ if text_encoder_2_lora_layers:
1402
+ state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
1403
+
1404
+ # Save the model
1405
+ cls.write_lora_layers(
1406
+ state_dict=state_dict,
1407
+ save_directory=save_directory,
1408
+ is_main_process=is_main_process,
1409
+ weight_name=weight_name,
1410
+ save_function=save_function,
1411
+ safe_serialization=safe_serialization,
1412
+ )
1413
+
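Conceptually, the packing step above namespaces each component's keys with its prefix so that the load path's prefix filtering can route them back to the right module; the written file is a single flat state dict. A manual sketch of that merge (dummy keys, and assuming `pack_weights` does nothing beyond prefixing):

```py
# Dummy per-component LoRA dicts.
transformer_layers = {"transformer_blocks.0.attn.to_q.lora_A.weight": 0}
text_encoder_layers = {"text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": 0}

state_dict = {}
state_dict.update({f"transformer.{k}": v for k, v in transformer_layers.items()})
state_dict.update({f"text_encoder.{k}": v for k, v in text_encoder_layers.items()})

print(sorted(state_dict))
# ['text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight',
#  'transformer.transformer_blocks.0.attn.to_q.lora_A.weight']
```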
1414
+ def fuse_lora(
1415
+ self,
1416
+ components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
1417
+ lora_scale: float = 1.0,
1418
+ safe_fusing: bool = False,
1419
+ adapter_names: Optional[List[str]] = None,
1420
+ **kwargs,
1421
+ ):
1422
+ r"""
1423
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
1424
+
1425
+ <Tip warning={true}>
1426
+
1427
+ This is an experimental API.
1428
+
1429
+ </Tip>
1430
+
1431
+ Args:
1432
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
1433
+ lora_scale (`float`, defaults to 1.0):
1434
+ Controls how much to influence the outputs with the LoRA parameters.
1435
+ safe_fusing (`bool`, defaults to `False`):
1436
+ Whether to check the fused weights for NaN values before fusing and, if any are NaN, to skip fusing them.
1437
+ adapter_names (`List[str]`, *optional*):
1438
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
1439
+
1440
+ Example:
1441
+
1442
+ ```py
1443
+ from diffusers import DiffusionPipeline
1444
+ import torch
1445
+
1446
+ pipeline = DiffusionPipeline.from_pretrained(
1447
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
1448
+ ).to("cuda")
1449
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
1450
+ pipeline.fuse_lora(lora_scale=0.7)
1451
+ ```
1452
+ """
1453
+ super().fuse_lora(
1454
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
1455
+ )
1456
+
1457
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
1458
+ r"""
1459
+ Reverses the effect of
1460
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
1461
+
1462
+ <Tip warning={true}>
1463
+
1464
+ This is an experimental API.
1465
+
1466
+ </Tip>
1467
+
1468
+ Args:
1469
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
1474
+ """
1475
+ super().unfuse_lora(components=components)
1476
+
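For SD3 the default `components` list is `["transformer", "text_encoder", "text_encoder_2"]`, but fusing only the transformer is common when the LoRA does not touch the text encoders. A brief sketch (checkpoint and LoRA paths are assumptions):

```py
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("./sd3-lora", weight_name="pytorch_lora_weights.safetensors")

pipe.fuse_lora(components=["transformer"], lora_scale=1.0)  # fold the adapter into the transformer
pipe.unfuse_lora(components=["transformer"])                # restore the original weights
```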
1477
+
1478
+ class FluxLoraLoaderMixin(LoraBaseMixin):
1479
+ r"""
1480
+ Load LoRA layers into [`FluxTransformer2DModel`] and
1481
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
1482
+
1483
+ Specific to [`FluxPipeline`].
1484
+ """
1485
+
1486
+ _lora_loadable_modules = ["transformer", "text_encoder"]
1487
+ transformer_name = TRANSFORMER_NAME
1488
+ text_encoder_name = TEXT_ENCODER_NAME
1489
+
1490
+ @classmethod
1491
+ @validate_hf_hub_args
1492
+ def lora_state_dict(
1493
+ cls,
1494
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
1495
+ return_alphas: bool = False,
1496
+ **kwargs,
1497
+ ):
1498
+ r"""
1499
+ Return the state dict for the LoRA weights and, if `return_alphas=True`, the network alphas.
1500
+
1501
+ <Tip warning={true}>
1502
+
1503
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
1504
+
1505
+ This function is experimental and might change in the future.
1506
+
1507
+ </Tip>
1508
+
1509
+ Parameters:
1510
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
1511
+ Can be either:
1512
+
1513
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
1514
+ the Hub.
1515
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
1516
+ with [`ModelMixin.save_pretrained`].
1517
+ - A [torch state
1518
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
1519
+
1520
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
1521
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
1522
+ is not used.
1523
+ force_download (`bool`, *optional*, defaults to `False`):
1524
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1525
+ cached versions if they exist.
1526
+
1527
+ proxies (`Dict[str, str]`, *optional*):
1528
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
1529
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1530
+ local_files_only (`bool`, *optional*, defaults to `False`):
1531
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
1532
+ won't be downloaded from the Hub.
1533
+ token (`str` or *bool*, *optional*):
1534
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
1535
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
1536
+ revision (`str`, *optional*, defaults to `"main"`):
1537
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
1538
+ allowed by Git.
1539
+ subfolder (`str`, *optional*, defaults to `""`):
1540
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
1541
+
1542
+ """
1543
+ # Load the main state dict first which has the LoRA layers for either of
1544
+ # transformer and text encoder or both.
1545
+ cache_dir = kwargs.pop("cache_dir", None)
1546
+ force_download = kwargs.pop("force_download", False)
1547
+ proxies = kwargs.pop("proxies", None)
1548
+ local_files_only = kwargs.pop("local_files_only", None)
1549
+ token = kwargs.pop("token", None)
1550
+ revision = kwargs.pop("revision", None)
1551
+ subfolder = kwargs.pop("subfolder", None)
1552
+ weight_name = kwargs.pop("weight_name", None)
1553
+ use_safetensors = kwargs.pop("use_safetensors", None)
1554
+
1555
+ allow_pickle = False
1556
+ if use_safetensors is None:
1557
+ use_safetensors = True
1558
+ allow_pickle = True
1559
+
1560
+ user_agent = {
1561
+ "file_type": "attn_procs_weights",
1562
+ "framework": "pytorch",
1563
+ }
1564
+
1565
+ state_dict = cls._fetch_state_dict(
1566
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
1567
+ weight_name=weight_name,
1568
+ use_safetensors=use_safetensors,
1569
+ local_files_only=local_files_only,
1570
+ cache_dir=cache_dir,
1571
+ force_download=force_download,
1572
+ proxies=proxies,
1573
+ token=token,
1574
+ revision=revision,
1575
+ subfolder=subfolder,
1576
+ user_agent=user_agent,
1577
+ allow_pickle=allow_pickle,
1578
+ )
1579
+
1580
+ # For state dicts like
1581
+ # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
1582
+ keys = list(state_dict.keys())
1583
+ network_alphas = {}
1584
+ for k in keys:
1585
+ if "alpha" in k:
1586
+ alpha_value = state_dict.get(k)
1587
+ if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
1588
+ alpha_value, float
1589
+ ):
1590
+ network_alphas[k] = state_dict.pop(k)
1591
+ else:
1592
+ raise ValueError(
1593
+ f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue."
1594
+ )
1595
+
1596
+ if return_alphas:
1597
+ return state_dict, network_alphas
1598
+ else:
1599
+ return state_dict
1600
+
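Unlike the SD3 variant, the Flux `lora_state_dict` can split kohya-style `alpha` scalars out of the checkpoint when `return_alphas=True`, which is what the loop above does. A sketch using the repository mentioned in the comment; the weight file name is an assumption, so check the repository for the actual name:

```py
from diffusers import FluxPipeline

state_dict, network_alphas = FluxPipeline.lora_state_dict(
    "TheLastBen/Jon_Snow_Flux_LoRA",     # repository referenced in the comment above
    weight_name="jon_snow.safetensors",  # assumed file name
    return_alphas=True,
)
print(len(state_dict), len(network_alphas))
```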
1601
+ def load_lora_weights(
1602
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
1603
+ ):
1604
+ """
1605
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
1606
+ `self.text_encoder`.
1607
+
1608
+ All kwargs are forwarded to `self.lora_state_dict`.
1609
+
1610
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
1611
+ loaded.
1612
+
1613
+ See [`~loaders.FluxLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
1614
+ dict is loaded into `self.transformer`.
1615
+
1616
+ Parameters:
1617
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
1618
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
1619
+ kwargs (`dict`, *optional*):
1620
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
1621
+ adapter_name (`str`, *optional*):
1622
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
1623
+ `default_{i}` where i is the total number of adapters being loaded.
1624
+ """
1625
+ if not USE_PEFT_BACKEND:
1626
+ raise ValueError("PEFT backend is required for this method.")
1627
+
1628
+ # if a dict is passed, copy it instead of modifying it inplace
1629
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
1630
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
1631
+
1632
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
1633
+ state_dict, network_alphas = self.lora_state_dict(
1634
+ pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
1635
+ )
1636
+
1637
+ is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
1638
+ if not is_correct_format:
1639
+ raise ValueError("Invalid LoRA checkpoint.")
1640
+
1641
+ self.load_lora_into_transformer(
1642
+ state_dict,
1643
+ network_alphas=network_alphas,
1644
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
1645
+ adapter_name=adapter_name,
1646
+ _pipeline=self,
1647
+ )
1648
+
1649
+ text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
1650
+ if len(text_encoder_state_dict) > 0:
1651
+ self.load_lora_into_text_encoder(
1652
+ text_encoder_state_dict,
1653
+ network_alphas=network_alphas,
1654
+ text_encoder=self.text_encoder,
1655
+ prefix="text_encoder",
1656
+ lora_scale=self.lora_scale,
1657
+ adapter_name=adapter_name,
1658
+ _pipeline=self,
1659
+ )
1660
+
1661
+ @classmethod
1662
+ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
1663
+ """
1664
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
1665
+
1666
+ Parameters:
1667
+ state_dict (`dict`):
1668
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
1669
+ into the transformer or prefixed with an additional `transformer`, which can be used to distinguish them
1670
+ from the text encoder LoRA layers.
1671
+ network_alphas (`Dict[str, float]`):
1672
+ The value of the network alpha used for stable learning and preventing underflow. This value has the
1673
+ same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
1674
+ link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
1675
+ transformer (`SD3Transformer2DModel`):
1676
+ The Transformer model to load the LoRA layers into.
1677
+ adapter_name (`str`, *optional*):
1678
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
1679
+ `default_{i}` where i is the total number of adapters being loaded.
1680
+ """
1681
+ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
1682
+
1683
+ keys = list(state_dict.keys())
1684
+
1685
+ transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
1686
+ state_dict = {
1687
+ k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
1688
+ }
1689
+
1690
+ if len(state_dict.keys()) > 0:
1691
+ # check with first key if is not in peft format
1692
+ first_key = next(iter(state_dict.keys()))
1693
+ if "lora_A" not in first_key:
1694
+ state_dict = convert_unet_state_dict_to_peft(state_dict)
1695
+
1696
+ if adapter_name in getattr(transformer, "peft_config", {}):
1697
+ raise ValueError(
1698
+ f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
1699
+ )
1700
+
1701
+ rank = {}
1702
+ for key, val in state_dict.items():
1703
+ if "lora_B" in key:
1704
+ rank[key] = val.shape[1]
1705
+
1706
+ if network_alphas is not None and len(network_alphas) >= 1:
1707
+ prefix = cls.transformer_name
1708
+ alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
1709
+ network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
1710
+
1711
+ lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
1712
+ if "use_dora" in lora_config_kwargs:
1713
+ if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
1714
+ raise ValueError(
1715
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
1716
+ )
1717
+ else:
1718
+ lora_config_kwargs.pop("use_dora")
1719
+ lora_config = LoraConfig(**lora_config_kwargs)
1720
+
1721
+ # adapter_name
1722
+ if adapter_name is None:
1723
+ adapter_name = get_adapter_name(transformer)
1724
+
1725
+ # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
1726
+ # otherwise loading LoRA weights will lead to an error
1727
+ is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
1728
+
1729
+ inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
1730
+ incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
1731
+
1732
+ if incompatible_keys is not None:
1733
+ # check only for unexpected keys
1734
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
1735
+ if unexpected_keys:
1736
+ logger.warning(
1737
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
1738
+ f" {unexpected_keys}. "
1739
+ )
1740
+
1741
+ # Offload back.
1742
+ if is_model_cpu_offload:
1743
+ _pipeline.enable_model_cpu_offload()
1744
+ elif is_sequential_cpu_offload:
1745
+ _pipeline.enable_sequential_cpu_offload()
1746
+ # Unsafe code />
1747
+
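For context on why the alphas collected by the method above matter: PEFT scales each LoRA update by `alpha / r`, so the per-module alphas change the effective strength of `lora_B @ lora_A`. A one-line illustration:

```py
alpha, rank = 4.0, 8
scaling = alpha / rank  # multiplier applied to the low-rank update lora_B @ lora_A
print(scaling)  # 0.5
```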
1748
+ @classmethod
1749
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
1750
+ def load_lora_into_text_encoder(
1751
+ cls,
1752
+ state_dict,
1753
+ network_alphas,
1754
+ text_encoder,
1755
+ prefix=None,
1756
+ lora_scale=1.0,
1757
+ adapter_name=None,
1758
+ _pipeline=None,
1759
+ ):
1760
+ """
1761
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
1762
+
1763
+ Parameters:
1764
+ state_dict (`dict`):
1765
+ A standard state dict containing the lora layer parameters. The key should be prefixed with an
1766
+ additional `text_encoder` to distinguish it from the UNet LoRA layers.
1767
+ network_alphas (`Dict[str, float]`):
1768
+ See `LoRALinearLayer` for more details.
1769
+ text_encoder (`CLIPTextModel`):
1770
+ The text encoder model to load the LoRA layers into.
1771
+ prefix (`str`):
1772
+ Expected prefix of the `text_encoder` in the `state_dict`.
1773
+ lora_scale (`float`):
1774
+ How much to scale the output of the LoRA linear layer before it is added to the output of the
1775
+ original layer.
1776
+ adapter_name (`str`, *optional*):
1777
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
1778
+ `default_{i}` where i is the total number of adapters being loaded.
1779
+ """
1780
+ if not USE_PEFT_BACKEND:
1781
+ raise ValueError("PEFT backend is required for this method.")
1782
+
1783
+ from peft import LoraConfig
1784
+
1785
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
1786
+ # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
1787
+ # their prefixes.
1788
+ keys = list(state_dict.keys())
1789
+ prefix = cls.text_encoder_name if prefix is None else prefix
1790
+
1791
+ # Safe prefix to check with.
1792
+ if any(cls.text_encoder_name in key for key in keys):
1793
+ # Load the layers corresponding to text encoder and make necessary adjustments.
1794
+ text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
1795
+ text_encoder_lora_state_dict = {
1796
+ k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
1797
+ }
1798
+
1799
+ if len(text_encoder_lora_state_dict) > 0:
1800
+ logger.info(f"Loading {prefix}.")
1801
+ rank = {}
1802
+ text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
1803
+
1804
+ # convert state dict
1805
+ text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
1806
+
1807
+ for name, _ in text_encoder_attn_modules(text_encoder):
1808
+ for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
1809
+ rank_key = f"{name}.{module}.lora_B.weight"
1810
+ if rank_key not in text_encoder_lora_state_dict:
1811
+ continue
1812
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
1813
+
1814
+ for name, _ in text_encoder_mlp_modules(text_encoder):
1815
+ for module in ("fc1", "fc2"):
1816
+ rank_key = f"{name}.{module}.lora_B.weight"
1817
+ if rank_key not in text_encoder_lora_state_dict:
1818
+ continue
1819
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
1820
+
1821
+ if network_alphas is not None:
1822
+ alpha_keys = [
1823
+ k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
1824
+ ]
1825
+ network_alphas = {
1826
+ k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
1827
+ }
1828
+
1829
+ lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
1830
+ if "use_dora" in lora_config_kwargs:
1831
+ if lora_config_kwargs["use_dora"]:
1832
+ if is_peft_version("<", "0.9.0"):
1833
+ raise ValueError(
1834
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
1835
+ )
1836
+ else:
1837
+ if is_peft_version("<", "0.9.0"):
1838
+ lora_config_kwargs.pop("use_dora")
1839
+ lora_config = LoraConfig(**lora_config_kwargs)
1840
+
1841
+ # adapter_name
1842
+ if adapter_name is None:
1843
+ adapter_name = get_adapter_name(text_encoder)
1844
+
1845
+ is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
1846
+
1847
+ # inject LoRA layers and load the state dict
1848
+ # in transformers we automatically check whether the adapter name is already in use or not
1849
+ text_encoder.load_adapter(
1850
+ adapter_name=adapter_name,
1851
+ adapter_state_dict=text_encoder_lora_state_dict,
1852
+ peft_config=lora_config,
1853
+ )
1854
+
1855
+ # scale LoRA layers with `lora_scale`
1856
+ scale_lora_layers(text_encoder, weight=lora_scale)
1857
+
1858
+ text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
1859
+
1860
+ # Offload back.
1861
+ if is_model_cpu_offload:
1862
+ _pipeline.enable_model_cpu_offload()
1863
+ elif is_sequential_cpu_offload:
1864
+ _pipeline.enable_sequential_cpu_offload()
1865
+ # Unsafe code />
1866
+
1867
+ @classmethod
1868
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
1869
+ def save_lora_weights(
1870
+ cls,
1871
+ save_directory: Union[str, os.PathLike],
1872
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1873
+ text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
1874
+ is_main_process: bool = True,
1875
+ weight_name: str = None,
1876
+ save_function: Callable = None,
1877
+ safe_serialization: bool = True,
1878
+ ):
1879
+ r"""
1880
+ Save the LoRA parameters corresponding to the transformer and text encoder.
1881
+
1882
+ Arguments:
1883
+ save_directory (`str` or `os.PathLike`):
1884
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
1885
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
1886
+ State dict of the LoRA layers corresponding to the `transformer`.
1887
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
1888
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
1889
+ encoder LoRA state dict because it comes from 🤗 Transformers.
1890
+ is_main_process (`bool`, *optional*, defaults to `True`):
1891
+ Whether the process calling this is the main process or not. Useful during distributed training when you
1892
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
1893
+ process to avoid race conditions.
1894
+ save_function (`Callable`):
1895
+ The function to use to save the state dictionary. Useful during distributed training when you need to
1896
+ replace `torch.save` with another method. Can be configured with the environment variable
1897
+ `DIFFUSERS_SAVE_MODE`.
1898
+ safe_serialization (`bool`, *optional*, defaults to `True`):
1899
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
1900
+ """
1901
+ state_dict = {}
1902
+
1903
+ if not (transformer_lora_layers or text_encoder_lora_layers):
1904
+ raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
1905
+
1906
+ if transformer_lora_layers:
1907
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
1908
+
1909
+ if text_encoder_lora_layers:
1910
+ state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
1911
+
1912
+ # Save the model
1913
+ cls.write_lora_layers(
1914
+ state_dict=state_dict,
1915
+ save_directory=save_directory,
1916
+ is_main_process=is_main_process,
1917
+ weight_name=weight_name,
1918
+ save_function=save_function,
1919
+ safe_serialization=safe_serialization,
1920
+ )
1921
+
1922
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
1923
+ def fuse_lora(
1924
+ self,
1925
+ components: List[str] = ["transformer", "text_encoder"],
1926
+ lora_scale: float = 1.0,
1927
+ safe_fusing: bool = False,
1928
+ adapter_names: Optional[List[str]] = None,
1929
+ **kwargs,
1930
+ ):
1931
+ r"""
1932
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
1933
+
1934
+ <Tip warning={true}>
1935
+
1936
+ This is an experimental API.
1937
+
1938
+ </Tip>
1939
+
1940
+ Args:
1941
+ components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
1942
+ lora_scale (`float`, defaults to 1.0):
1943
+ Controls how much to influence the outputs with the LoRA parameters.
1944
+ safe_fusing (`bool`, defaults to `False`):
1945
+ Whether to check the fused weights for NaN values before fusing and, if any are NaN, to skip fusing them.
1946
+ adapter_names (`List[str]`, *optional*):
1947
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
1948
+
1949
+ Example:
1950
+
1951
+ ```py
1952
+ from diffusers import DiffusionPipeline
1953
+ import torch
1954
+
1955
+ pipeline = DiffusionPipeline.from_pretrained(
1956
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
1957
+ ).to("cuda")
1958
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
1959
+ pipeline.fuse_lora(lora_scale=0.7)
1960
+ ```
1961
+ """
1962
+ super().fuse_lora(
1963
+ components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
1964
+ )
1965
+
1966
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
1967
+ r"""
1968
+ Reverses the effect of
1969
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
1970
+
1971
+ <Tip warning={true}>
1972
+
1973
+ This is an experimental API.
1974
+
1975
+ </Tip>
1976
+
1977
+ Args:
1978
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
1979
+ """
1980
+ super().unfuse_lora(components=components)
1981
+
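An end-to-end sketch for the Flux mixin; the base checkpoint is gated on the Hub and the LoRA repository/file name are assumptions, so substitute whatever you have access to:

```py
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")

image = pipe("jon snow eating pizza", guidance_scale=3.5, num_inference_steps=28).images[0]
image.save("flux_lora.png")
```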
1982
+
1983
+ # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
1984
+ # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
1985
+ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
1986
+ _lora_loadable_modules = ["transformer", "text_encoder"]
1987
+ transformer_name = TRANSFORMER_NAME
1988
+ text_encoder_name = TEXT_ENCODER_NAME
1989
+
1990
+ @classmethod
1991
+ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
1992
+ """
1993
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
1994
+
1995
+ Parameters:
1996
+ state_dict (`dict`):
1997
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
1998
+ into the transformer or prefixed with an additional `transformer`, which can be used to distinguish them
1999
+ from the text encoder LoRA layers.
2000
+ network_alphas (`Dict[str, float]`):
2001
+ See `LoRALinearLayer` for more details.
2002
+ transformer (`UVit2DModel`):
2003
+ The Transformer model to load the LoRA layers into.
2004
+ adapter_name (`str`, *optional*):
2005
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
2006
+ `default_{i}` where i is the total number of adapters being loaded.
2007
+ """
2008
+ if not USE_PEFT_BACKEND:
2009
+ raise ValueError("PEFT backend is required for this method.")
2010
+
2011
+ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
2012
+
2013
+ keys = list(state_dict.keys())
2014
+
2015
+ transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
2016
+ state_dict = {
2017
+ k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
2018
+ }
2019
+
2020
+ if network_alphas is not None:
2021
+ alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.transformer_name)]
2022
+ network_alphas = {
2023
+ k.replace(f"{cls.transformer_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
2024
+ }
2025
+
2026
+ if len(state_dict.keys()) > 0:
2027
+ if adapter_name in getattr(transformer, "peft_config", {}):
2028
+ raise ValueError(
2029
+ f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
2030
+ )
2031
+
2032
+ rank = {}
2033
+ for key, val in state_dict.items():
2034
+ if "lora_B" in key:
2035
+ rank[key] = val.shape[1]
2036
+
2037
+ lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
2038
+ if "use_dora" in lora_config_kwargs:
2039
+ if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
2040
+ raise ValueError(
2041
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
2042
+ )
2043
+ else:
2044
+ lora_config_kwargs.pop("use_dora")
2045
+ lora_config = LoraConfig(**lora_config_kwargs)
2046
+
2047
+ # adapter_name
2048
+ if adapter_name is None:
2049
+ adapter_name = get_adapter_name(transformer)
2050
+
2051
+ # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
2052
+ # otherwise loading LoRA weights will lead to an error
2053
+ is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
2054
+
2055
+ inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
2056
+ incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
2057
+
2058
+ if incompatible_keys is not None:
2059
+ # check only for unexpected keys
2060
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
2061
+ if unexpected_keys:
2062
+ logger.warning(
2063
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
2064
+ f" {unexpected_keys}. "
2065
+ )
2066
+
2067
+ # Offload back.
2068
+ if is_model_cpu_offload:
2069
+ _pipeline.enable_model_cpu_offload()
2070
+ elif is_sequential_cpu_offload:
2071
+ _pipeline.enable_sequential_cpu_offload()
2072
+ # Unsafe code />
2073
+
2074
+ @classmethod
2075
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
2076
+ def load_lora_into_text_encoder(
2077
+ cls,
2078
+ state_dict,
2079
+ network_alphas,
2080
+ text_encoder,
2081
+ prefix=None,
2082
+ lora_scale=1.0,
2083
+ adapter_name=None,
2084
+ _pipeline=None,
2085
+ ):
2086
+ """
2087
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
2088
+
2089
+ Parameters:
2090
+ state_dict (`dict`):
2091
+ A standard state dict containing the lora layer parameters. The key should be prefixed with an
2092
+ additional `text_encoder` to distinguish it from the UNet LoRA layers.
2093
+ network_alphas (`Dict[str, float]`):
2094
+ See `LoRALinearLayer` for more details.
2095
+ text_encoder (`CLIPTextModel`):
2096
+ The text encoder model to load the LoRA layers into.
2097
+ prefix (`str`):
2098
+ Expected prefix of the `text_encoder` in the `state_dict`.
2099
+ lora_scale (`float`):
2100
+ How much to scale the output of the LoRA linear layer before it is added to the output of the
2101
+ original layer.
2102
+ adapter_name (`str`, *optional*):
2103
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
2104
+ `default_{i}` where i is the total number of adapters being loaded.
2105
+ """
2106
+ if not USE_PEFT_BACKEND:
2107
+ raise ValueError("PEFT backend is required for this method.")
2108
+
2109
+ from peft import LoraConfig
2110
+
2111
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
2112
+ # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
2113
+ # their prefixes.
2114
+ keys = list(state_dict.keys())
2115
+ prefix = cls.text_encoder_name if prefix is None else prefix
2116
+
2117
+ # Safe prefix to check with.
2118
+ if any(cls.text_encoder_name in key for key in keys):
2119
+ # Load the layers corresponding to text encoder and make necessary adjustments.
2120
+ text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
2121
+ text_encoder_lora_state_dict = {
2122
+ k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
2123
+ }
2124
+
2125
+ if len(text_encoder_lora_state_dict) > 0:
2126
+ logger.info(f"Loading {prefix}.")
2127
+ rank = {}
2128
+ text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
2129
+
2130
+ # convert state dict
2131
+ text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
2132
+
2133
+ for name, _ in text_encoder_attn_modules(text_encoder):
2134
+ for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
2135
+ rank_key = f"{name}.{module}.lora_B.weight"
2136
+ if rank_key not in text_encoder_lora_state_dict:
2137
+ continue
2138
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
2139
+
2140
+ for name, _ in text_encoder_mlp_modules(text_encoder):
2141
+ for module in ("fc1", "fc2"):
2142
+ rank_key = f"{name}.{module}.lora_B.weight"
2143
+ if rank_key not in text_encoder_lora_state_dict:
2144
+ continue
2145
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
2146
+
2147
+ if network_alphas is not None:
2148
+ alpha_keys = [
2149
+ k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
2150
+ ]
2151
+ network_alphas = {
2152
+ k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
2153
+ }
2154
+
2155
+ lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
2156
+ if "use_dora" in lora_config_kwargs:
2157
+ if lora_config_kwargs["use_dora"]:
2158
+ if is_peft_version("<", "0.9.0"):
2159
+ raise ValueError(
2160
+ "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
2161
+ )
2162
+ else:
2163
+ if is_peft_version("<", "0.9.0"):
2164
+ lora_config_kwargs.pop("use_dora")
2165
+ lora_config = LoraConfig(**lora_config_kwargs)
2166
+
2167
+ # adapter_name
2168
+ if adapter_name is None:
2169
+ adapter_name = get_adapter_name(text_encoder)
2170
+
2171
+ is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
2172
+
2173
+ # inject LoRA layers and load the state dict
2174
+ # in transformers we automatically check whether the adapter name is already in use or not
2175
+ text_encoder.load_adapter(
2176
+ adapter_name=adapter_name,
2177
+ adapter_state_dict=text_encoder_lora_state_dict,
2178
+ peft_config=lora_config,
2179
+ )
2180
+
2181
+ # scale LoRA layers with `lora_scale`
2182
+ scale_lora_layers(text_encoder, weight=lora_scale)
2183
+
2184
+ text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
2185
+
2186
+ # Offload back.
2187
+ if is_model_cpu_offload:
2188
+ _pipeline.enable_model_cpu_offload()
2189
+ elif is_sequential_cpu_offload:
2190
+ _pipeline.enable_sequential_cpu_offload()
2191
+ # Unsafe code />
2192
+
2193
+ @classmethod
2194
+ def save_lora_weights(
2195
+ cls,
2196
+ save_directory: Union[str, os.PathLike],
2197
+ text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
2198
+ transformer_lora_layers: Dict[str, torch.nn.Module] = None,
2199
+ is_main_process: bool = True,
2200
+ weight_name: str = None,
2201
+ save_function: Callable = None,
2202
+ safe_serialization: bool = True,
2203
+ ):
2204
+ r"""
2205
+ Save the LoRA parameters corresponding to the transformer and text encoder.
2206
+
2207
+ Arguments:
2208
+ save_directory (`str` or `os.PathLike`):
2209
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
2210
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
2211
+ State dict of the LoRA layers corresponding to the `transformer`.
2212
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
2213
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
2214
+ encoder LoRA state dict because it comes from 🤗 Transformers.
2215
+ is_main_process (`bool`, *optional*, defaults to `True`):
2216
+ Whether the process calling this is the main process or not. Useful during distributed training when you
2217
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
2218
+ process to avoid race conditions.
2219
+ save_function (`Callable`):
2220
+ The function to use to save the state dictionary. Useful during distributed training when you need to
2221
+ replace `torch.save` with another method. Can be configured with the environment variable
2222
+ `DIFFUSERS_SAVE_MODE`.
2223
+ safe_serialization (`bool`, *optional*, defaults to `True`):
2224
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
2225
+ """
2226
+ state_dict = {}
2227
+
2228
+ if not (transformer_lora_layers or text_encoder_lora_layers):
2229
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
2230
+
2231
+ if transformer_lora_layers:
2232
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
2233
+
2234
+ if text_encoder_lora_layers:
2235
+ state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
2236
+
2237
+ # Save the model
2238
+ cls.write_lora_layers(
2239
+ state_dict=state_dict,
2240
+ save_directory=save_directory,
2241
+ is_main_process=is_main_process,
2242
+ weight_name=weight_name,
2243
+ save_function=save_function,
2244
+ safe_serialization=safe_serialization,
2245
+ )
2246
+
2247
+
2248
+ class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
2249
+ def __init__(self, *args, **kwargs):
2250
+ deprecation_message = "LoraLoaderMixin is deprecated and will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin` instead."
2251
+ deprecate("LoraLoaderMixin", "1.0.0", deprecation_message)
2252
+ super().__init__(*args, **kwargs)
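The shim above keeps the old spelling importable while steering callers to the new name; assuming both classes stay exported from `diffusers.loaders`, downstream code sees something like this:

```py
# Old spelling: still importable through the shim, but instantiating a pipeline that uses it
# triggers the deprecation message defined above.
from diffusers.loaders import LoraLoaderMixin

# New spelling to migrate to.
from diffusers.loaders import StableDiffusionLoraLoaderMixin
```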