diffusers 0.29.2__py3-none-any.whl → 0.30.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +94 -3
- diffusers/commands/env.py +1 -5
- diffusers/configuration_utils.py +4 -9
- diffusers/dependency_versions_table.py +2 -2
- diffusers/image_processor.py +1 -2
- diffusers/loaders/__init__.py +17 -2
- diffusers/loaders/ip_adapter.py +10 -7
- diffusers/loaders/lora_base.py +752 -0
- diffusers/loaders/lora_pipeline.py +2252 -0
- diffusers/loaders/peft.py +213 -5
- diffusers/loaders/single_file.py +3 -14
- diffusers/loaders/single_file_model.py +31 -10
- diffusers/loaders/single_file_utils.py +293 -8
- diffusers/loaders/textual_inversion.py +1 -6
- diffusers/loaders/unet.py +23 -208
- diffusers/models/__init__.py +20 -0
- diffusers/models/activations.py +22 -0
- diffusers/models/attention.py +386 -7
- diffusers/models/attention_processor.py +1937 -629
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_kl.py +14 -3
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1271 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vq_model.py +4 -4
- diffusers/models/controlnet.py +2 -3
- diffusers/models/controlnet_hunyuan.py +401 -0
- diffusers/models/controlnet_sd3.py +11 -11
- diffusers/models/controlnet_sparsectrl.py +789 -0
- diffusers/models/controlnet_xs.py +40 -10
- diffusers/models/downsampling.py +68 -0
- diffusers/models/embeddings.py +403 -36
- diffusers/models/model_loading_utils.py +1 -3
- diffusers/models/modeling_flax_utils.py +1 -6
- diffusers/models/modeling_utils.py +4 -16
- diffusers/models/normalization.py +203 -12
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +543 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +485 -0
- diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
- diffusers/models/transformers/latte_transformer_3d.py +327 -0
- diffusers/models/transformers/lumina_nextdit2d.py +340 -0
- diffusers/models/transformers/pixart_transformer_2d.py +102 -1
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/stable_audio_transformer.py +458 -0
- diffusers/models/transformers/transformer_flux.py +455 -0
- diffusers/models/transformers/transformer_sd3.py +18 -4
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +8 -1
- diffusers/models/unets/unet_3d_blocks.py +51 -920
- diffusers/models/unets/unet_3d_condition.py +4 -1
- diffusers/models/unets/unet_i2vgen_xl.py +4 -1
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +1330 -84
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +1 -3
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +64 -0
- diffusers/models/vq_model.py +8 -4
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +100 -3
- diffusers/pipelines/animatediff/__init__.py +4 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
- diffusers/pipelines/aura_flow/__init__.py +48 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
- diffusers/pipelines/auto_pipeline.py +97 -19
- diffusers/pipelines/cogvideo/__init__.py +48 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +746 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +749 -0
- diffusers/pipelines/flux/pipeline_output.py +21 -0
- diffusers/pipelines/free_init_utils.py +2 -0
- diffusers/pipelines/free_noise_utils.py +236 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
- diffusers/pipelines/kolors/__init__.py +54 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
- diffusers/pipelines/kolors/pipeline_output.py +21 -0
- diffusers/pipelines/kolors/text_encoder.py +889 -0
- diffusers/pipelines/kolors/tokenizer.py +334 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
- diffusers/pipelines/latte/__init__.py +48 -0
- diffusers/pipelines/latte/pipeline_latte.py +881 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
- diffusers/pipelines/lumina/__init__.py +48 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
- diffusers/pipelines/pag/__init__.py +67 -0
- diffusers/pipelines/pag/pag_utils.py +237 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
- diffusers/pipelines/pia/pipeline_pia.py +30 -37
- diffusers/pipelines/pipeline_flax_utils.py +4 -9
- diffusers/pipelines/pipeline_loading_utils.py +0 -3
- diffusers/pipelines/pipeline_utils.py +2 -14
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
- diffusers/pipelines/stable_audio/__init__.py +50 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
- diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
- diffusers/schedulers/__init__.py +8 -0
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
- diffusers/schedulers/scheduling_ddim.py +1 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
- diffusers/schedulers/scheduling_ddpm.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +2 -2
- diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
- diffusers/schedulers/scheduling_ipndm.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
- diffusers/schedulers/scheduling_utils.py +1 -3
- diffusers/schedulers/scheduling_utils_flax.py +1 -3
- diffusers/training_utils.py +99 -14
- diffusers/utils/__init__.py +2 -2
- diffusers/utils/dummy_pt_objects.py +210 -0
- diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
- diffusers/utils/dynamic_modules_utils.py +1 -11
- diffusers/utils/export_utils.py +50 -6
- diffusers/utils/hub_utils.py +45 -42
- diffusers/utils/import_utils.py +37 -15
- diffusers/utils/loading_utils.py +80 -3
- diffusers/utils/testing_utils.py +11 -8
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/METADATA +73 -83
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/RECORD +217 -164
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/WHEEL +1 -1
- diffusers/loaders/autoencoder.py +0 -146
- diffusers/loaders/controlnet.py +0 -136
- diffusers/loaders/lora.py +0 -1728
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/LICENSE +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2252 @@
|
|
1
|
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import os
|
15
|
+
from typing import Callable, Dict, List, Optional, Union
|
16
|
+
|
17
|
+
import torch
|
18
|
+
from huggingface_hub.utils import validate_hf_hub_args
|
19
|
+
|
20
|
+
from ..utils import (
|
21
|
+
USE_PEFT_BACKEND,
|
22
|
+
convert_state_dict_to_diffusers,
|
23
|
+
convert_state_dict_to_peft,
|
24
|
+
convert_unet_state_dict_to_peft,
|
25
|
+
deprecate,
|
26
|
+
get_adapter_name,
|
27
|
+
get_peft_kwargs,
|
28
|
+
is_peft_version,
|
29
|
+
is_transformers_available,
|
30
|
+
logging,
|
31
|
+
scale_lora_layers,
|
32
|
+
)
|
33
|
+
from .lora_base import LoraBaseMixin
|
34
|
+
from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
|
35
|
+
|
36
|
+
|
37
|
+
if is_transformers_available():
|
38
|
+
from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
|
39
|
+
|
40
|
+
logger = logging.get_logger(__name__)

# Default pipeline attribute names of the components LoRA layers can be loaded into
# (used e.g. as `unet_name` / `text_encoder_name` on the loader mixins below).
TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"
TRANSFORMER_NAME = "transformer"

# File names for LoRA weights: the pickled `.bin` variant and the `.safetensors` variant share one stem.
_LORA_WEIGHT_BASENAME = "pytorch_lora_weights"
LORA_WEIGHT_NAME = f"{_LORA_WEIGHT_BASENAME}.bin"
LORA_WEIGHT_NAME_SAFE = f"{_LORA_WEIGHT_BASENAME}.safetensors"
|
48
|
+
|
49
|
+
|
50
|
+
class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
51
|
+
r"""
|
52
|
+
Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
|
53
|
+
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
|
54
|
+
"""
|
55
|
+
|
56
|
+
_lora_loadable_modules = ["unet", "text_encoder"]
|
57
|
+
unet_name = UNET_NAME
|
58
|
+
text_encoder_name = TEXT_ENCODER_NAME
|
59
|
+
|
60
|
+
def load_lora_weights(
    self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
):
    """
    Load LoRA weights from `pretrained_model_name_or_path_or_dict` into the pipeline's UNet and text encoder.

    The checkpoint is resolved by [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] (which receives all
    `kwargs`), validated, and then handed to [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] and
    [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`], in that order.

    Parameters:
        pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
            See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. A dict argument is shallow-copied
            first, so the caller's dict is never modified in place.
        adapter_name (`str`, *optional*):
            Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
            `default_{i}` where i is the total number of adapters being loaded.
        kwargs (`dict`, *optional*):
            Forwarded unchanged to [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].

    Raises:
        ValueError: If the PEFT backend is unavailable, or if any key of the resolved state dict contains
            neither `"lora"` nor `"dora_scale"`.
    """
    if not USE_PEFT_BACKEND:
        raise ValueError("PEFT backend is required for this method.")

    source = pretrained_model_name_or_path_or_dict
    if isinstance(source, dict):
        # Work on a shallow copy so the caller's dict is left untouched.
        source = source.copy()

    # Resolve and validate the checkpoint before touching any model.
    state_dict, network_alphas = self.lora_state_dict(source, **kwargs)

    if any("lora" not in key and "dora_scale" not in key for key in state_dict.keys()):
        raise ValueError("Invalid LoRA checkpoint.")

    # Prefer the conventional `unet` attribute, otherwise fall back to the configurable attribute name.
    target_unet = self.unet if hasattr(self, "unet") else getattr(self, self.unet_name)
    self.load_lora_into_unet(
        state_dict,
        network_alphas=network_alphas,
        unet=target_unet,
        adapter_name=adapter_name,
        _pipeline=self,
    )

    target_text_encoder = (
        self.text_encoder if hasattr(self, "text_encoder") else getattr(self, self.text_encoder_name)
    )
    self.load_lora_into_text_encoder(
        state_dict,
        network_alphas=network_alphas,
        text_encoder=target_text_encoder,
        lora_scale=self.lora_scale,
        adapter_name=adapter_name,
        _pipeline=self,
    )
|
118
|
+
|
119
|
+
@classmethod
@validate_hf_hub_args
def lora_state_dict(
    cls,
    pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
    **kwargs,
):
    r"""
    Fetch a LoRA checkpoint and return `(state_dict, network_alphas)`.

    <Tip warning={true}>

    We support loading A1111 formatted LoRA checkpoints in a limited capacity.

    This function is experimental and might change in the future.

    </Tip>

    When every key of the fetched state dict starts with one of `lora_te_`, `lora_unet_`, `lora_te1_` or
    `lora_te2_`, the checkpoint is treated as a non-diffusers LoRA and converted to the diffusers layout; only in
    that case can `network_alphas` be something other than `None`.

    Parameters:
        pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
            Can be either:

                - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                  the Hub.
                - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                  with [`ModelMixin.save_pretrained`].
                - A [torch state
                  dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

        cache_dir (`Union[str, os.PathLike]`, *optional*):
            Directory in which downloaded files are cached when the standard cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether to (re-)download the files even if cached versions exist.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, only local files are used and nothing is downloaded from the Hub.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
            `diffusers-cli login` (stored in `~/.huggingface`) is used.
        revision (`str`, *optional*, defaults to `"main"`):
            The model version to use: a branch name, a tag name, a commit id, or any identifier allowed by Git.
        subfolder (`str`, *optional*, defaults to `""`):
            The subfolder location of a model file within a larger model repository on the Hub or locally.
        weight_name (`str`, *optional*, defaults to None):
            Name of the serialized state dict file.
        unet_config (*optional*):
            When given, used to remap SGM-style block numbering before converting a non-diffusers checkpoint.
        use_safetensors (`bool`, *optional*):
            Whether to load safetensors weights. When left unset, safetensors is requested and pickled weights
            are also allowed.

    Any other keyword argument is ignored.
    """
    # Hub / download options.
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", None)
    token = kwargs.pop("token", None)
    revision = kwargs.pop("revision", None)
    subfolder = kwargs.pop("subfolder", None)

    # File selection and format options.
    weight_name = kwargs.pop("weight_name", None)
    use_safetensors = kwargs.pop("use_safetensors", None)
    unet_config = kwargs.pop("unet_config", None)

    # No explicit preference: request safetensors but also permit pickled weights
    # (presumably a `.bin` fallback inside `_fetch_state_dict` — verify there).
    allow_pickle = use_safetensors is None
    if allow_pickle:
        use_safetensors = True

    state_dict = cls._fetch_state_dict(
        pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
        weight_name=weight_name,
        use_safetensors=use_safetensors,
        local_files_only=local_files_only,
        cache_dir=cache_dir,
        force_download=force_download,
        proxies=proxies,
        token=token,
        revision=revision,
        subfolder=subfolder,
        user_agent={"file_type": "attn_procs_weights", "framework": "pytorch"},
        allow_pickle=allow_pickle,
    )

    # TODO: replace it with a method from `state_dict_utils`
    non_diffusers_prefixes = ("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_")
    if not all(key.startswith(non_diffusers_prefixes) for key in state_dict.keys()):
        # Already in the diffusers layout: nothing to convert, no alphas.
        return state_dict, None

    if unet_config is not None:
        # Map SDXL (SGM) block numbering onto the diffusers layout using the UNet config.
        state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
    state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
    return state_dict, network_alphas
|
228
|
+
|
229
|
+
@classmethod
def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
    """
    Load the LoRA layers held in `state_dict` into `unet`.

    Nothing is loaded when every key of `state_dict` starts with `cls.text_encoder_name` (a checkpoint that only
    carries text-encoder layers), and likewise when `state_dict` is empty.

    Parameters:
        state_dict (`dict`):
            A standard state dict containing the lora layer parameters. The keys can either be indexed directly
            into the unet or prefixed with an additional `unet` which can be used to distinguish between text
            encoder lora layers.
        network_alphas (`Dict[str, float]`):
            The value of the network alpha used for stable learning and preventing underflow. This value has the
            same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
            link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
        unet (`UNet2DConditionModel`):
            The UNet model to load the LoRA layers into.
        adapter_name (`str`, *optional*):
            Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
            `default_{i}` where i is the total number of adapters being loaded.
        _pipeline (*optional*):
            Forwarded unchanged to `unet.load_attn_procs`.

    Raises:
        ValueError: If the PEFT backend is unavailable.
    """
    if not USE_PEFT_BACKEND:
        raise ValueError("PEFT backend is required for this method.")

    # With the serialization format introduced in https://github.com/huggingface/diffusers/pull/2918, keys carry
    # `cls.unet_name` and/or `cls.text_encoder_name` as prefixes; a text-encoder-only checkpoint has nothing here.
    te_prefix = cls.text_encoder_name
    has_unet_layers = any(not key.startswith(te_prefix) for key in state_dict.keys())
    if not has_unet_layers:
        return

    logger.info(f"Loading {cls.unet_name}.")
    unet.load_attn_procs(state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline)
|
263
|
+
|
264
|
+
@classmethod
def load_lora_into_text_encoder(
    cls,
    state_dict,
    network_alphas,
    text_encoder,
    prefix=None,
    lora_scale=1.0,
    adapter_name=None,
    _pipeline=None,
):
    """
    This will load the LoRA layers specified in `state_dict` into `text_encoder`.

    Only entries whose key starts with `f"{prefix}."` are used; that leading prefix (and only the leading one) is
    stripped before the weights are converted to the PEFT layout and loaded as an adapter. Nothing is loaded when
    no key mentions `cls.text_encoder_name` or when no key carries the prefix.

    Parameters:
        state_dict (`dict`):
            A standard state dict containing the lora layer parameters. The keys should be prefixed with an
            additional `text_encoder` (or `prefix`) to distinguish them from unet lora layers.
        network_alphas (`Dict[str, float]`, *optional*):
            See `LoRALinearLayer` for more details. Filtered and stripped with the same prefix rule as
            `state_dict`.
        text_encoder (`CLIPTextModel`):
            The text encoder model to load the LoRA layers into.
        prefix (`str`, *optional*):
            Expected prefix of the `text_encoder` in the `state_dict`. Defaults to `cls.text_encoder_name`.
        lora_scale (`float`):
            How much to scale the output of the lora linear layer before it is added with the output of the
            regular (non-LoRA) layer.
        adapter_name (`str`, *optional*):
            Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
            `default_{i}` where i is the total number of adapters being loaded.
        _pipeline (*optional*):
            Pipeline whose CPU offloading mode is passed to `cls._optionally_disable_offloading` before loading and
            re-enabled afterwards (also when loading fails).

    Raises:
        ValueError: If the PEFT backend is unavailable, or if the checkpoint uses DoRA and the installed `peft` is
            older than 0.9.0.
    """
    if not USE_PEFT_BACKEND:
        raise ValueError("PEFT backend is required for this method.")

    from peft import LoraConfig

    # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
    # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
    # their prefixes.
    prefix = cls.text_encoder_name if prefix is None else prefix

    # Safe prefix to check with.
    if not any(cls.text_encoder_name in key for key in state_dict.keys()):
        return

    # Keep only this text encoder's entries and strip the leading `f"{prefix}."`. A single pass avoids the
    # quadratic `k in <list>` membership test, and slicing (unlike `str.replace`) removes only the leading
    # prefix, never a later occurrence inside the key.
    prefix_dot = f"{prefix}."
    text_encoder_lora_state_dict = {
        k[len(prefix_dot) :]: v for k, v in state_dict.items() if k.startswith(prefix_dot)
    }
    if not text_encoder_lora_state_dict:
        return

    logger.info(f"Loading {prefix}.")
    text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)

    # convert state dict
    text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)

    # LoRA rank of each targeted projection, read from dim 1 of its `lora_B` weight.
    rank = {}
    rank_targets = (
        (text_encoder_attn_modules, ("out_proj", "q_proj", "k_proj", "v_proj")),
        (text_encoder_mlp_modules, ("fc1", "fc2")),
    )
    for iter_modules, module_names in rank_targets:
        for name, _ in iter_modules(text_encoder):
            for module in module_names:
                rank_key = f"{name}.{module}.lora_B.weight"
                if rank_key in text_encoder_lora_state_dict:
                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

    if network_alphas is not None:
        network_alphas = {
            k[len(prefix_dot) :]: v for k, v in network_alphas.items() if k.startswith(prefix_dot)
        }

    lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
    if "use_dora" in lora_config_kwargs and is_peft_version("<", "0.9.0"):
        if lora_config_kwargs["use_dora"]:
            raise ValueError(
                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
            )
        # Older `peft` versions do not accept the `use_dora` argument at all.
        lora_config_kwargs.pop("use_dora")
    lora_config = LoraConfig(**lora_config_kwargs)

    if adapter_name is None:
        adapter_name = get_adapter_name(text_encoder)

    is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)

    try:
        # inject LoRA layers and load the state dict
        # in transformers we automatically check whether the adapter name is already in use or not
        text_encoder.load_adapter(
            adapter_name=adapter_name,
            adapter_state_dict=text_encoder_lora_state_dict,
            peft_config=lora_config,
        )

        # scale LoRA layers with `lora_scale`
        scale_lora_layers(text_encoder, weight=lora_scale)

        text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
    finally:
        # Offload back, even if loading failed, so the pipeline is not left without its offloading hooks.
        if is_model_cpu_offload:
            _pipeline.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            _pipeline.enable_sequential_cpu_offload()
|
381
|
+
|
382
|
+
@classmethod
def save_lora_weights(
    cls,
    save_directory: Union[str, os.PathLike],
    unet_lora_layers: Optional[Dict[str, Union[torch.nn.Module, torch.Tensor]]] = None,
    text_encoder_lora_layers: Optional[Dict[str, Union[torch.nn.Module, torch.Tensor]]] = None,
    is_main_process: bool = True,
    weight_name: Optional[str] = None,
    save_function: Optional[Callable] = None,
    safe_serialization: bool = True,
):
    r"""
    Save the LoRA parameters corresponding to the UNet and text encoder.

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Directory to save LoRA parameters to. Will be created if it doesn't exist.
        unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`, *optional*):
            State dict of the LoRA layers corresponding to the `unet`.
        text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`, *optional*):
            State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
            encoder LoRA state dict because it comes from 🤗 Transformers.
        is_main_process (`bool`, *optional*, defaults to `True`):
            Whether the process calling this is the main process or not. Useful during distributed training and you
            need to call this function on all processes. In this case, set `is_main_process=True` only on the main
            process to avoid race conditions.
        weight_name (`str`, *optional*):
            Name of the serialized weights file. Forwarded unchanged to `write_lora_layers`.
        save_function (`Callable`, *optional*):
            The function to use to save the state dictionary. Useful during distributed training when you need to
            replace `torch.save` with another method. Can be configured with the environment variable
            `DIFFUSERS_SAVE_MODE`.
        safe_serialization (`bool`, *optional*, defaults to `True`):
            Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.

    Raises:
        ValueError: If neither `unet_lora_layers` nor `text_encoder_lora_layers` is given (or both are empty).
    """
    if not (unet_lora_layers or text_encoder_lora_layers):
        raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")

    state_dict = {}

    # Each component is packed under its own name (`cls.unet_name` / `cls.text_encoder_name`); the loaders in
    # this file route keys by the same names, so saving and loading stay symmetric.
    if unet_lora_layers:
        state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))

    if text_encoder_lora_layers:
        state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))

    # Save the model
    cls.write_lora_layers(
        state_dict=state_dict,
        save_directory=save_directory,
        is_main_process=is_main_process,
        weight_name=weight_name,
        save_function=save_function,
        safe_serialization=safe_serialization,
    )
|
435
|
+
|
436
|
+
def fuse_lora(
    self,
    components: Optional[List[str]] = None,
    lora_scale: float = 1.0,
    safe_fusing: bool = False,
    adapter_names: Optional[List[str]] = None,
    **kwargs,
):
    r"""
    Fuses the LoRA parameters into the original parameters of the corresponding blocks.

    <Tip warning={true}>

    This is an experimental API.

    </Tip>

    Args:
        components (`List[str]`, *optional*, defaults to `["unet", "text_encoder"]`):
            List of LoRA-injectable components to fuse the LoRAs into.
        lora_scale (`float`, defaults to 1.0):
            Controls how much to influence the outputs with the LoRA parameters.
        safe_fusing (`bool`, defaults to `False`):
            Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
        adapter_names (`List[str]`, *optional*):
            Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
        kwargs:
            Forwarded to the base-class `fuse_lora`. NOTE(review): presumably used there to warn about legacy
            flags such as `fuse_unet` / `fuse_text_encoder` — confirm against `LoraBaseMixin.fuse_lora`.

    Example:

    ```py
    from diffusers import DiffusionPipeline
    import torch

    pipeline = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")
    pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
    pipeline.fuse_lora(lora_scale=0.7)
    ```
    """
    if components is None:
        # Fresh list per call rather than a shared mutable default.
        components = ["unet", "text_encoder"]
    # Forward kwargs instead of silently discarding them, so the base class can act on (or warn about) them.
    super().fuse_lora(
        components=components,
        lora_scale=lora_scale,
        safe_fusing=safe_fusing,
        adapter_names=adapter_names,
        **kwargs,
    )
|
478
|
+
|
479
|
+
def unfuse_lora(self, components: Optional[List[str]] = None, **kwargs):
    r"""
    Reverses the effect of
    [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).

    <Tip warning={true}>

    This is an experimental API.

    </Tip>

    Args:
        components (`List[str]`, *optional*, defaults to `["unet", "text_encoder"]`):
            List of LoRA-injectable components to unfuse LoRA from.
        kwargs:
            Forwarded to the base-class `unfuse_lora`. NOTE(review): legacy flags such as `unfuse_unet` /
            `unfuse_text_encoder` presumably arrive here and are handled (deprecated) by
            `LoraBaseMixin.unfuse_lora` — confirm there; they are not parameters of this method.
    """
    if components is None:
        # Fresh list per call rather than a shared mutable default.
        components = ["unet", "text_encoder"]
    super().unfuse_lora(components=components, **kwargs)
|
498
|
+
|
499
|
+
|
500
|
+
class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into Stable Diffusion XL [`UNet2DConditionModel`],
    [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), and
    [`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection).
    """

    _lora_loadable_modules = ["unet", "text_encoder", "text_encoder_2"]
    unet_name = UNET_NAME
    text_encoder_name = TEXT_ENCODER_NAME

    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name: Optional[str] = None,
        **kwargs,
    ):
        """
        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet`,
        `self.text_encoder` and `self.text_encoder_2`.

        All kwargs are forwarded to `self.lora_state_dict`.

        See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
        loaded.

        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is
        loaded into `self.unet`.

        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state
        dict is loaded into `self.text_encoder` and `self.text_encoder_2`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            kwargs (`dict`, *optional*):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].

        Raises:
            ValueError: If the PEFT backend is unavailable, or if any key of the loaded state dict contains neither
                `"lora"` nor `"dora_scale"`.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        # We could have accessed the unet config from `lora_state_dict()` too. We pass
        # it here explicitly to be able to tell that it's coming from an SDXL
        # pipeline.

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict, network_alphas = self.lora_state_dict(
            pretrained_model_name_or_path_or_dict,
            unet_config=self.unet.config,
            **kwargs,
        )
        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        self.load_lora_into_unet(
            state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
        )

        # Keys starting with "text_encoder_2." do not contain the substring "text_encoder.", so the two
        # text-encoder subsets below do not overlap.
        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
        if len(text_encoder_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder,
                prefix="text_encoder",
                lora_scale=self.lora_scale,
                adapter_name=adapter_name,
                _pipeline=self,
            )

        # `self.text_encoder_2` is only accessed when there are weights for it.
        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
        if len(text_encoder_2_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_2_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder_2,
                prefix="text_encoder_2",
                lora_scale=self.lora_scale,
                adapter_name=adapter_name,
                _pipeline=self,
            )

    @classmethod
    @validate_hf_hub_args
    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        r"""
        Return state dict for lora weights and the network alphas.

        <Tip warning={true}>

        We support loading A1111 formatted LoRA checkpoints in a limited capacity.

        This function is experimental and might change in the future.

        </Tip>

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.

            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            weight_name (`str`, *optional*, defaults to None):
                Name of the serialized state dict file.
        """
        # Load the main state dict first which has the LoRA layers for either of
        # UNet and text encoder or both.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        unet_config = kwargs.pop("unet_config", None)
        use_safetensors = kwargs.pop("use_safetensors", None)

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        state_dict = cls._fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        network_alphas = None
        # TODO: replace it with a method from `state_dict_utils`
        if all(
            (
                k.startswith("lora_te_")
                or k.startswith("lora_unet_")
                or k.startswith("lora_te1_")
                or k.startswith("lora_te2_")
            )
            for k in state_dict.keys()
        ):
            # Map SDXL blocks correctly.
            if unet_config is not None:
                # use unet config to remap block numbers
                state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
            state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)

        return state_dict, network_alphas

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
    def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
        """
        This will load the LoRA layers specified in `state_dict` into `unet`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                encoder lora layers.
            network_alphas (`Dict[str, float]`):
                The value of the network alpha used for stable learning and preventing underflow. This value has the
                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
            unet (`UNet2DConditionModel`):
                The UNet model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
        # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
        # their prefixes.
        keys = list(state_dict.keys())
        only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
        if not only_text_encoder:
            # Load the layers corresponding to UNet.
            logger.info(f"Loading {cls.unet_name}.")
            unet.load_attn_procs(
                state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
            )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
    def load_lora_into_text_encoder(
        cls,
        state_dict,
        network_alphas,
        text_encoder,
        prefix=None,
        lora_scale=1.0,
        adapter_name=None,
        _pipeline=None,
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `text_encoder`

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The key should be prefixed with an
                additional `text_encoder` to distinguish between unet lora layers.
            network_alphas (`Dict[str, float]`):
                See `LoRALinearLayer` for more details.
            text_encoder (`CLIPTextModel`):
                The text encoder model to load the LoRA layers into.
            prefix (`str`):
                Expected prefix of the `text_encoder` in the `state_dict`.
            lora_scale (`float`):
                How much to scale the output of the lora linear layer before it is added with the output of the regular
                lora layer.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        from peft import LoraConfig

        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
        # their prefixes.
        keys = list(state_dict.keys())
        prefix = cls.text_encoder_name if prefix is None else prefix

        # Safe prefix to check with.
        if any(cls.text_encoder_name in key for key in keys):
            # Load the layers corresponding to text encoder and make necessary adjustments.
            text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
            text_encoder_lora_state_dict = {
                k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
            }

            if len(text_encoder_lora_state_dict) > 0:
                logger.info(f"Loading {prefix}.")
                rank = {}
                text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)

                # convert state dict
                text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)

                for name, _ in text_encoder_attn_modules(text_encoder):
                    for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
                        rank_key = f"{name}.{module}.lora_B.weight"
                        if rank_key not in text_encoder_lora_state_dict:
                            continue
                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

                for name, _ in text_encoder_mlp_modules(text_encoder):
                    for module in ("fc1", "fc2"):
                        rank_key = f"{name}.{module}.lora_B.weight"
                        if rank_key not in text_encoder_lora_state_dict:
                            continue
                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

                if network_alphas is not None:
                    alpha_keys = [
                        k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
                    ]
                    network_alphas = {
                        k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
                    }

                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
                if "use_dora" in lora_config_kwargs:
                    if lora_config_kwargs["use_dora"]:
                        if is_peft_version("<", "0.9.0"):
                            raise ValueError(
                                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
                            )
                    else:
                        if is_peft_version("<", "0.9.0"):
                            lora_config_kwargs.pop("use_dora")
                lora_config = LoraConfig(**lora_config_kwargs)

                # adapter_name
                if adapter_name is None:
                    adapter_name = get_adapter_name(text_encoder)

                is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)

                # inject LoRA layers and load the state dict
                # in transformers we automatically check whether the adapter name is already in use or not
                text_encoder.load_adapter(
                    adapter_name=adapter_name,
                    adapter_state_dict=text_encoder_lora_state_dict,
                    peft_config=lora_config,
                )

                # scale LoRA layers with `lora_scale`
                scale_lora_layers(text_encoder, weight=lora_scale)

                text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)

                # Offload back.
                if is_model_cpu_offload:
                    _pipeline.enable_model_cpu_offload()
                elif is_sequential_cpu_offload:
                    _pipeline.enable_sequential_cpu_offload()

    @classmethod
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        unet_lora_layers: Optional[Dict[str, Union[torch.nn.Module, torch.Tensor]]] = None,
        text_encoder_lora_layers: Optional[Dict[str, Union[torch.nn.Module, torch.Tensor]]] = None,
        text_encoder_2_lora_layers: Optional[Dict[str, Union[torch.nn.Module, torch.Tensor]]] = None,
        is_main_process: bool = True,
        weight_name: Optional[str] = None,
        save_function: Optional[Callable] = None,
        safe_serialization: bool = True,
    ):
        r"""
        Save the LoRA parameters corresponding to the UNet and both text encoders.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`, *optional*):
                State dict of the LoRA layers corresponding to the `unet`.
            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`, *optional*):
                State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                encoder LoRA state dict because it comes from 🤗 Transformers.
            text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`, *optional*):
                State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
                encoder LoRA state dict because it comes from 🤗 Transformers.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training and you
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                process to avoid race conditions.
            weight_name (`str`, *optional*):
                Name of the serialized weights file. Forwarded unchanged to `write_lora_layers`.
            save_function (`Callable`, *optional*):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.

        Raises:
            ValueError: If none of the three `*_lora_layers` arguments is given (or all are empty).
        """
        if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
            raise ValueError(
                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
            )

        state_dict = {}

        # Pack under the same class-level names that `load_lora_into_unet` / `load_lora_into_text_encoder`
        # key off, so a subclass overriding `unet_name` / `text_encoder_name` saves what it can load.
        if unet_lora_layers:
            state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))

        if text_encoder_lora_layers:
            state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))

        if text_encoder_2_lora_layers:
            state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

        cls.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    def fuse_lora(
        self,
        components: Optional[List[str]] = None,
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
        r"""
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components (`List[str]`, *optional*, defaults to `["unet", "text_encoder", "text_encoder_2"]`):
                List of LoRA-injectable components to fuse the LoRAs into.
            lora_scale (`float`, defaults to 1.0):
                Controls how much to influence the outputs with the LoRA parameters.
            safe_fusing (`bool`, defaults to `False`):
                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
            adapter_names (`List[str]`, *optional*):
                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
            kwargs:
                Forwarded to the base-class `fuse_lora`. NOTE(review): presumably used there to warn about legacy
                flags such as `fuse_unet` / `fuse_text_encoder` — confirm against `LoraBaseMixin.fuse_lora`.

        Example:

        ```py
        from diffusers import DiffusionPipeline
        import torch

        pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
        if components is None:
            # Fresh list per call rather than a shared mutable default.
            components = ["unet", "text_encoder", "text_encoder_2"]
        # Forward kwargs instead of silently discarding them, so the base class can act on (or warn about) them.
        super().fuse_lora(
            components=components,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
            **kwargs,
        )

    def unfuse_lora(self, components: Optional[List[str]] = None, **kwargs):
        r"""
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            components (`List[str]`, *optional*, defaults to `["unet", "text_encoder", "text_encoder_2"]`):
                List of LoRA-injectable components to unfuse LoRA from.
            kwargs:
                Forwarded to the base-class `unfuse_lora`. NOTE(review): legacy flags such as `unfuse_unet` /
                `unfuse_text_encoder` presumably arrive here and are handled (deprecated) by
                `LoraBaseMixin.unfuse_lora` — confirm there; they are not parameters of this method.
        """
        if components is None:
            # Fresh list per call rather than a shared mutable default.
            components = ["unet", "text_encoder", "text_encoder_2"]
        super().unfuse_lora(components=components, **kwargs)
|
979
|
+
|
980
|
+
|
981
|
+
class SD3LoraLoaderMixin(LoraBaseMixin):
|
982
|
+
r"""
|
983
|
+
Load LoRA layers into [`SD3Transformer2DModel`],
|
984
|
+
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), and
|
985
|
+
[`CLIPTextModelWithProjection`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection).
|
986
|
+
|
987
|
+
Specific to [`StableDiffusion3Pipeline`].
|
988
|
+
"""
|
989
|
+
|
990
|
+
_lora_loadable_modules = ["transformer", "text_encoder", "text_encoder_2"]
|
991
|
+
transformer_name = TRANSFORMER_NAME
|
992
|
+
text_encoder_name = TEXT_ENCODER_NAME
|
993
|
+
|
994
|
+
@classmethod
@validate_hf_hub_args
def lora_state_dict(
    cls,
    pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
    **kwargs,
):
    r"""
    Return the LoRA state dict for the SD3 transformer and/or text encoders.

    Unlike the UNet-based loaders, this returns only the state dict (there is no network-alphas element), and no
    key-format conversion (e.g. from A1111/kohya checkpoints) is performed here: the state dict is returned exactly
    as fetched by `_fetch_state_dict`.

    <Tip warning={true}>

    This function is experimental and might change in the future.

    </Tip>

    Parameters:
        pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
            Can be either:

                - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                  the Hub.
                - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                  with [`ModelMixin.save_pretrained`].
                - A [torch state
                  dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

        cache_dir (`Union[str, os.PathLike]`, *optional*):
            Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
            is not used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        local_files_only (`bool`, *optional*):
            Whether to only load local model weights and configuration files or not. If set to `True`, the model
            won't be downloaded from the Hub.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
            `diffusers-cli login` (stored in `~/.huggingface`) is used.
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
            allowed by Git.
        subfolder (`str`, *optional*, defaults to `""`):
            The subfolder location of a model file within a larger model repository on the Hub or locally.
        weight_name (`str`, *optional*):
            Name of the serialized state dict file.
        use_safetensors (`bool`, *optional*):
            When left as `None`, it is treated as `True` and `allow_pickle=True` is passed to `_fetch_state_dict`;
            when set explicitly, that value is used and `allow_pickle=False` is passed.

    Returns:
            The state dict returned by `_fetch_state_dict`, unmodified.
    """
    # Load the main state dict first which has the LoRA layers for either of
    # transformer and text encoder or both.
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", None)
    token = kwargs.pop("token", None)
    revision = kwargs.pop("revision", None)
    subfolder = kwargs.pop("subfolder", None)
    weight_name = kwargs.pop("weight_name", None)
    use_safetensors = kwargs.pop("use_safetensors", None)

    # Pickle files are only permitted when the caller did not state a safetensors preference.
    allow_pickle = False
    if use_safetensors is None:
        use_safetensors = True
        allow_pickle = True

    user_agent = {
        "file_type": "attn_procs_weights",
        "framework": "pytorch",
    }

    state_dict = cls._fetch_state_dict(
        pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
        weight_name=weight_name,
        use_safetensors=use_safetensors,
        local_files_only=local_files_only,
        cache_dir=cache_dir,
        force_download=force_download,
        proxies=proxies,
        token=token,
        revision=revision,
        subfolder=subfolder,
        user_agent=user_agent,
        allow_pickle=allow_pickle,
    )

    return state_dict
|
1084
|
+
|
1085
|
+
def load_lora_weights(
    self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
):
    """
    Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`,
    `self.text_encoder` and `self.text_encoder_2`.

    All kwargs are forwarded to `self.lora_state_dict`.

    See [`~loaders.SD3LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.

    See [`~loaders.SD3LoraLoaderMixin.load_lora_into_transformer`] for more details on how the state dict is
    loaded into `self.transformer`.

    Entries whose key contains `"text_encoder."` / `"text_encoder_2."` are loaded into `self.text_encoder` /
    `self.text_encoder_2` via `load_lora_into_text_encoder`, using `self.lora_scale` as the LoRA scale.

    Parameters:
        pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
            See [`~loaders.SD3LoraLoaderMixin.lora_state_dict`]. If a dict is passed it is shallow-copied
            first, so the caller's dict itself is not modified.
        kwargs (`dict`, *optional*):
            See [`~loaders.SD3LoraLoaderMixin.lora_state_dict`].
        adapter_name (`str`, *optional*):
            Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
            `default_{i}` where i is the total number of adapters being loaded.

    Raises:
        ValueError: If the PEFT backend is not available, or if a key of the loaded state dict contains
            neither `"lora"` nor `"dora_scale"`.
    """
    if not USE_PEFT_BACKEND:
        raise ValueError("PEFT backend is required for this method.")

    # if a dict is passed, copy it instead of modifying it inplace
    if isinstance(pretrained_model_name_or_path_or_dict, dict):
        pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

    # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
    state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

    is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
    if not is_correct_format:
        raise ValueError("Invalid LoRA checkpoint.")

    self.load_lora_into_transformer(
        state_dict,
        transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
        adapter_name=adapter_name,
        _pipeline=self,
    )

    # Both text encoders are handled identically; the trailing "." keeps "text_encoder." from matching
    # "text_encoder_2." keys.
    for text_encoder_attr in ("text_encoder", "text_encoder_2"):
        text_encoder_state_dict = {k: v for k, v in state_dict.items() if f"{text_encoder_attr}." in k}
        if len(text_encoder_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_state_dict,
                network_alphas=None,
                text_encoder=getattr(self, text_encoder_attr),
                prefix=text_encoder_attr,
                lora_scale=self.lora_scale,
                adapter_name=adapter_name,
                _pipeline=self,
            )
|
1153
|
+
|
1154
|
+
@classmethod
def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
    """
    This will load the LoRA layers specified in `state_dict` into `transformer`.

    Parameters:
        state_dict (`dict`):
            A standard state dict containing the LoRA layer parameters. Only keys starting with
            `cls.transformer_name` are used, and `f"{cls.transformer_name}."` is removed from them before
            loading. If the first remaining key does not contain `"lora_A"`, the dict is converted with
            `convert_unet_state_dict_to_peft`.
        transformer (`SD3Transformer2DModel`):
            The Transformer model to load the LoRA layers into.
        adapter_name (`str`, *optional*):
            Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
            `default_{i}` where i is the total number of adapters being loaded.
        _pipeline (*optional*):
            Pipeline whose CPU-offloading hooks are temporarily removed while the adapter is loaded and
            re-enabled afterwards (also when loading fails).

    Raises:
        ValueError: If `adapter_name` is already used by `transformer`, or if the checkpoint is DoRA-enabled
            and the installed `peft` is older than 0.9.0.
    """
    from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict

    # Keep only the transformer entries. Filtering inline (instead of testing membership in a list of the
    # matching keys) keeps this linear in the number of keys.
    state_dict = {
        k.replace(f"{cls.transformer_name}.", ""): v
        for k, v in state_dict.items()
        if k.startswith(cls.transformer_name)
    }

    if len(state_dict.keys()) == 0:
        # Nothing to load for the transformer.
        return

    # check with first key if is not in peft format
    first_key = next(iter(state_dict.keys()))
    if "lora_A" not in first_key:
        state_dict = convert_unet_state_dict_to_peft(state_dict)

    if adapter_name in getattr(transformer, "peft_config", {}):
        raise ValueError(
            f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
        )

    # The rank of each LoRA layer is read from the second dimension of its `lora_B` tensor.
    rank = {key: val.shape[1] for key, val in state_dict.items() if "lora_B" in key}

    lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
    if "use_dora" in lora_config_kwargs:
        if lora_config_kwargs["use_dora"]:
            if is_peft_version("<", "0.9.0"):
                raise ValueError(
                    "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
                )
        elif is_peft_version("<", "0.9.0"):
            # Only strip the flag for peft releases that predate DoRA. A True flag must be kept so that
            # DoRA checkpoints are actually loaded as DoRA on newer peft versions.
            lora_config_kwargs.pop("use_dora")
    lora_config = LoraConfig(**lora_config_kwargs)

    # adapter_name
    if adapter_name is None:
        adapter_name = get_adapter_name(transformer)

    # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
    # otherwise loading LoRA weights will lead to an error
    is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)

    try:
        inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
        incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
    finally:
        # Offload back, even if loading raised, so the pipeline is not left without its hooks.
        if is_model_cpu_offload:
            _pipeline.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            _pipeline.enable_sequential_cpu_offload()
        # Unsafe code />

    if incompatible_keys is not None:
        # check only for unexpected keys
        unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
        if unexpected_keys:
            logger.warning(
                f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                f" {unexpected_keys}. "
            )
|
1231
|
+
|
1232
|
+
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
def load_lora_into_text_encoder(
    cls,
    state_dict,
    network_alphas,
    text_encoder,
    prefix=None,
    lora_scale=1.0,
    adapter_name=None,
    _pipeline=None,
):
    """
    This will load the LoRA layers specified in `state_dict` into `text_encoder`.

    If no key of `state_dict` contains `cls.text_encoder_name`, or no key's first dotted component equals
    `prefix`, the text encoder is left unchanged.

    Parameters:
        state_dict (`dict`):
            A standard state dict containing the lora layer parameters. Only keys whose first dotted component
            equals `prefix` are used, and `f"{prefix}."` is removed from them before loading.
        network_alphas (`Dict[str, float]`):
            See `LoRALinearLayer` for more details. May be `None`; otherwise it is filtered by `prefix` in the
            same way as `state_dict`.
        text_encoder (`CLIPTextModel`):
            The text encoder model to load the LoRA layers into.
        prefix (`str`, *optional*):
            Expected prefix of the `text_encoder` in the `state_dict`. Defaults to `cls.text_encoder_name`.
        lora_scale (`float`, *optional*, defaults to 1.0):
            Passed as `weight` to `scale_lora_layers(text_encoder, ...)` after the adapter is loaded;
            presumably scales the contribution of the LoRA layers relative to the base layers.
        adapter_name (`str`, *optional*):
            Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
            `default_{i}` where i is the total number of adapters being loaded.
        _pipeline (*optional*):
            Pipeline whose CPU-offloading hooks are temporarily disabled while the adapter is loaded and
            re-enabled afterwards.

    Raises:
        ValueError: If the PEFT backend is not available, or if the checkpoint is DoRA-enabled and the
            installed `peft` is older than 0.9.0.
    """
    if not USE_PEFT_BACKEND:
        raise ValueError("PEFT backend is required for this method.")

    from peft import LoraConfig

    # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
    # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
    # their prefixes.
    keys = list(state_dict.keys())
    prefix = cls.text_encoder_name if prefix is None else prefix

    # Safe prefix to check with.
    # NOTE: this gate is a substring test on `cls.text_encoder_name` (not on `prefix`); the exact per-prefix
    # filtering happens on the next lines.
    if any(cls.text_encoder_name in key for key in keys):
        # Load the layers corresponding to text encoder and make necessary adjustments.
        text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
        text_encoder_lora_state_dict = {
            k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
        }

        if len(text_encoder_lora_state_dict) > 0:
            logger.info(f"Loading {prefix}.")
            rank = {}
            text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)

            # convert state dict
            text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)

            # Collect the rank (second dimension of `lora_B`) of every attention projection that has LoRA
            # weights in the checkpoint.
            for name, _ in text_encoder_attn_modules(text_encoder):
                for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
                    rank_key = f"{name}.{module}.lora_B.weight"
                    if rank_key not in text_encoder_lora_state_dict:
                        continue
                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

            # Same for the MLP layers.
            for name, _ in text_encoder_mlp_modules(text_encoder):
                for module in ("fc1", "fc2"):
                    rank_key = f"{name}.{module}.lora_B.weight"
                    if rank_key not in text_encoder_lora_state_dict:
                        continue
                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

            # Keep only this text encoder's alphas and strip their prefix, mirroring the state dict filter.
            if network_alphas is not None:
                alpha_keys = [
                    k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
                ]
                network_alphas = {
                    k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
                }

            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
            # DoRA needs peft >= 0.9.0: fail for DoRA checkpoints on older peft, and drop the (False) flag
            # there so it is not passed to `LoraConfig`.
            if "use_dora" in lora_config_kwargs:
                if lora_config_kwargs["use_dora"]:
                    if is_peft_version("<", "0.9.0"):
                        raise ValueError(
                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
                        )
                else:
                    if is_peft_version("<", "0.9.0"):
                        lora_config_kwargs.pop("use_dora")
            lora_config = LoraConfig(**lora_config_kwargs)

            # adapter_name
            if adapter_name is None:
                adapter_name = get_adapter_name(text_encoder)

            # Temporarily remove the pipeline's CPU-offloading hooks while the adapter is injected.
            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)

            # inject LoRA layers and load the state dict
            # in transformers we automatically check whether the adapter name is already in use or not
            text_encoder.load_adapter(
                adapter_name=adapter_name,
                adapter_state_dict=text_encoder_lora_state_dict,
                peft_config=lora_config,
            )

            # scale LoRA layers with `lora_scale`
            scale_lora_layers(text_encoder, weight=lora_scale)

            # Re-apply the text encoder's own device/dtype to the newly injected layers.
            text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)

            # Offload back.
            if is_model_cpu_offload:
                _pipeline.enable_model_cpu_offload()
            elif is_sequential_cpu_offload:
                _pipeline.enable_sequential_cpu_offload()
            # Unsafe code />
|
1350
|
+
|
1351
|
+
@classmethod
def save_lora_weights(
    cls,
    save_directory: Union[str, os.PathLike],
    transformer_lora_layers: Optional[Dict[str, Union[torch.nn.Module, torch.Tensor]]] = None,
    text_encoder_lora_layers: Optional[Dict[str, Union[torch.nn.Module, torch.Tensor]]] = None,
    text_encoder_2_lora_layers: Optional[Dict[str, Union[torch.nn.Module, torch.Tensor]]] = None,
    is_main_process: bool = True,
    weight_name: Optional[str] = None,
    save_function: Optional[Callable] = None,
    safe_serialization: bool = True,
):
    r"""
    Save the LoRA parameters corresponding to the transformer and the two text encoders.

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Directory to save LoRA parameters to. Will be created if it doesn't exist.
        transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`, *optional*):
            State dict of the LoRA layers corresponding to the `transformer`. Packed under the
            `cls.transformer_name` prefix.
        text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`, *optional*):
            State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
            encoder LoRA state dict because it comes from 🤗 Transformers. Packed under the `text_encoder`
            prefix.
        text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`, *optional*):
            State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
            encoder LoRA state dict because it comes from 🤗 Transformers. Packed under the `text_encoder_2`
            prefix.
        is_main_process (`bool`, *optional*, defaults to `True`):
            Whether the process calling this is the main process or not. Useful during distributed training and you
            need to call this function on all processes. In this case, set `is_main_process=True` only on the main
            process to avoid race conditions.
        weight_name (`str`, *optional*):
            File name for the saved weights; forwarded unchanged to `write_lora_layers`.
        save_function (`Callable`, *optional*):
            The function to use to save the state dictionary. Useful during distributed training when you need to
            replace `torch.save` with another method. Can be configured with the environment variable
            `DIFFUSERS_SAVE_MODE`.
        safe_serialization (`bool`, *optional*, defaults to `True`):
            Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.

    Raises:
        ValueError: If none of `transformer_lora_layers`, `text_encoder_lora_layers` and
            `text_encoder_2_lora_layers` is provided (or all are empty).
    """
    if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
        raise ValueError(
            "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
        )

    # Merge every provided component into one flat state dict, each under its own key prefix.
    state_dict = {}
    if transformer_lora_layers:
        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

    if text_encoder_lora_layers:
        state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))

    if text_encoder_2_lora_layers:
        state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

    # Save the model
    cls.write_lora_layers(
        state_dict=state_dict,
        save_directory=save_directory,
        is_main_process=is_main_process,
        weight_name=weight_name,
        save_function=save_function,
        safe_serialization=safe_serialization,
    )
|
1413
|
+
|
1414
|
+
def fuse_lora(
    self,
    components: Optional[List[str]] = None,
    lora_scale: float = 1.0,
    safe_fusing: bool = False,
    adapter_names: Optional[List[str]] = None,
    **kwargs,
):
    r"""
    Fuses the LoRA parameters into the original parameters of the corresponding blocks.

    <Tip warning={true}>

    This is an experimental API.

    </Tip>

    Args:
        components (`List[str]`, *optional*):
            List of LoRA-injectable components to fuse the LoRAs into. Defaults to
            `["transformer", "text_encoder", "text_encoder_2"]`.
        lora_scale (`float`, defaults to 1.0):
            Controls how much to influence the outputs with the LoRA parameters.
        safe_fusing (`bool`, defaults to `False`):
            Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
        adapter_names (`List[str]`, *optional*):
            Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
        kwargs:
            Extra keyword arguments are accepted but not forwarded, so they have no effect.

    Example:

    ```py
    from diffusers import DiffusionPipeline
    import torch

    pipeline = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")
    pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
    pipeline.fuse_lora(lora_scale=0.7)
    ```
    """
    if components is None:
        # Build a fresh list per call instead of sharing a mutable default between calls.
        components = ["transformer", "text_encoder", "text_encoder_2"]
    super().fuse_lora(
        components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
    )
|
1456
|
+
|
1457
|
+
def unfuse_lora(self, components: Optional[List[str]] = None, **kwargs):
    r"""
    Reverses the effect of
    [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).

    <Tip warning={true}>

    This is an experimental API.

    </Tip>

    Args:
        components (`List[str]`, *optional*):
            List of LoRA-injectable components to unfuse LoRA from. Defaults to
            `["transformer", "text_encoder", "text_encoder_2"]`.
        kwargs:
            Extra keyword arguments (for example legacy flags such as `unfuse_unet` or `unfuse_text_encoder`)
            are accepted but not forwarded, so they have no effect. Use `components` to select what to unfuse.
    """
    if components is None:
        # Build a fresh list per call instead of sharing a mutable default between calls.
        components = ["transformer", "text_encoder", "text_encoder_2"]
    super().unfuse_lora(components=components)
|
1476
|
+
|
1477
|
+
|
1478
|
+
class FluxLoraLoaderMixin(LoraBaseMixin):
|
1479
|
+
r"""
|
1480
|
+
Load LoRA layers into [`FluxTransformer2DModel`],
|
1481
|
+
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
|
1482
|
+
|
1483
|
+
Specific to [`StableDiffusion3Pipeline`].
|
1484
|
+
"""
|
1485
|
+
|
1486
|
+
_lora_loadable_modules = ["transformer", "text_encoder"]
|
1487
|
+
transformer_name = TRANSFORMER_NAME
|
1488
|
+
text_encoder_name = TEXT_ENCODER_NAME
|
1489
|
+
|
1490
|
+
@classmethod
@validate_hf_hub_args
def lora_state_dict(
    cls,
    pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
    return_alphas: bool = False,
    **kwargs,
):
    r"""
    Return state dict for lora weights and the network alphas.

    <Tip warning={true}>

    We support loading A1111 formatted LoRA checkpoints in a limited capacity.

    This function is experimental and might change in the future.

    </Tip>

    Parameters:
        pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
            Can be either:

                - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                  the Hub.
                - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                  with [`ModelMixin.save_pretrained`].
                - A [torch state
                  dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

        return_alphas (`bool`, *optional*, defaults to `False`):
            If `True`, also return the extracted network alphas.
        cache_dir (`Union[str, os.PathLike]`, *optional*):
            Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
            is not used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force the (re-)download of the model weights and configuration files, overriding the
            cached versions if they exist.

        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
        local_files_only (`bool`, *optional*, defaults to `False`):
            Whether to only load local model weights and configuration files or not. If set to `True`, the model
            won't be downloaded from the Hub.
        token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
            `diffusers-cli login` (stored in `~/.huggingface`) is used.
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
            allowed by Git.
        subfolder (`str`, *optional*, defaults to `""`):
            The subfolder location of a model file within a larger model repository on the Hub or locally.
        weight_name (`str`, *optional*):
            Name of the weight file to load; forwarded to the state dict loader.
        use_safetensors (`bool`, *optional*):
            Forwarded to the state dict loader. When not given, `use_safetensors=True` and `allow_pickle=True`
            are passed (presumably permitting a fallback to non-safetensors files).

    Returns:
        The LoRA state dict with every key containing `"alpha"` removed; or, when `return_alphas=True`, a
        tuple `(state_dict, network_alphas)` where `network_alphas` maps those removed keys to their values.

    Raises:
        ValueError: If an `"alpha"` entry is neither a floating-point tensor nor a Python `float`.
    """
    # Load the main state dict first which has the LoRA layers for either of
    # transformer and text encoder or both.
    cache_dir = kwargs.pop("cache_dir", None)
    force_download = kwargs.pop("force_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", None)
    token = kwargs.pop("token", None)
    revision = kwargs.pop("revision", None)
    subfolder = kwargs.pop("subfolder", None)
    weight_name = kwargs.pop("weight_name", None)
    use_safetensors = kwargs.pop("use_safetensors", None)

    allow_pickle = False
    if use_safetensors is None:
        use_safetensors = True
        allow_pickle = True

    user_agent = {
        "file_type": "attn_procs_weights",
        "framework": "pytorch",
    }

    state_dict = cls._fetch_state_dict(
        pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
        weight_name=weight_name,
        use_safetensors=use_safetensors,
        local_files_only=local_files_only,
        cache_dir=cache_dir,
        force_download=force_download,
        proxies=proxies,
        token=token,
        revision=revision,
        subfolder=subfolder,
        user_agent=user_agent,
        allow_pickle=allow_pickle,
    )

    # For state dicts like
    # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
    # the "alpha" entries are split off into `network_alphas`. New dicts are built instead of popping from
    # `state_dict`, which may be the caller's own dict when a dict was passed in.
    network_alphas = {}
    lora_weights = {}
    for k, v in state_dict.items():
        if "alpha" not in k:
            lora_weights[k] = v
            continue
        if (torch.is_tensor(v) and torch.is_floating_point(v)) or isinstance(v, float):
            network_alphas[k] = v
        else:
            raise ValueError(
                f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue."
            )
    state_dict = lora_weights

    if return_alphas:
        return state_dict, network_alphas
    else:
        return state_dict
|
1600
|
+
|
1601
|
+
def load_lora_weights(
    self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
):
    """
    Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
    `self.text_encoder`.

    All kwargs are forwarded to `self.lora_state_dict`.

    See [`~loaders.FluxLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.

    See [`~loaders.FluxLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state dict is
    loaded into `self.transformer`.

    Entries whose key contains `"text_encoder."` are loaded into `self.text_encoder` via
    `load_lora_into_text_encoder`, using `self.lora_scale` as the LoRA scale. The network alphas extracted by
    `lora_state_dict` are passed to both the transformer and the text encoder loaders.

    Parameters:
        pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
            See [`~loaders.FluxLoraLoaderMixin.lora_state_dict`]. If a dict is passed it is shallow-copied
            first, so the caller's dict itself is not modified.
        kwargs (`dict`, *optional*):
            See [`~loaders.FluxLoraLoaderMixin.lora_state_dict`].
        adapter_name (`str`, *optional*):
            Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
            `default_{i}` where i is the total number of adapters being loaded.

    Raises:
        ValueError: If the PEFT backend is not available, or if a key of the loaded state dict contains
            neither `"lora"` nor `"dora_scale"`.
    """
    if not USE_PEFT_BACKEND:
        raise ValueError("PEFT backend is required for this method.")

    # if a dict is passed, copy it instead of modifying it inplace
    if isinstance(pretrained_model_name_or_path_or_dict, dict):
        pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

    # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
    state_dict, network_alphas = self.lora_state_dict(
        pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
    )

    # Alpha entries were already split off above, so every remaining key must be a LoRA/DoRA parameter.
    is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
    if not is_correct_format:
        raise ValueError("Invalid LoRA checkpoint.")

    self.load_lora_into_transformer(
        state_dict,
        network_alphas=network_alphas,
        transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
        adapter_name=adapter_name,
        _pipeline=self,
    )

    text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
    if len(text_encoder_state_dict) > 0:
        self.load_lora_into_text_encoder(
            text_encoder_state_dict,
            network_alphas=network_alphas,
            text_encoder=self.text_encoder,
            prefix="text_encoder",
            lora_scale=self.lora_scale,
            adapter_name=adapter_name,
            _pipeline=self,
        )
|
1660
|
+
|
1661
|
+
@classmethod
def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
    """
    This will load the LoRA layers specified in `state_dict` into `transformer`.

    Parameters:
        state_dict (`dict`):
            A standard state dict containing the LoRA layer parameters. Only keys starting with
            `cls.transformer_name` are used, and `f"{cls.transformer_name}."` is removed from them before
            loading. If the first remaining key does not contain `"lora_A"`, the dict is converted with
            `convert_unet_state_dict_to_peft`.
        network_alphas (`Dict[str, float]`):
            The value of the network alpha used for stable learning and preventing underflow. This value has the
            same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
            link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
            Only entries whose first dotted component equals `cls.transformer_name` are used (with that prefix
            removed).
        transformer (`FluxTransformer2DModel`):
            The Transformer model to load the LoRA layers into.
        adapter_name (`str`, *optional*):
            Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
            `default_{i}` where i is the total number of adapters being loaded.
        _pipeline (*optional*):
            Pipeline whose CPU-offloading hooks are temporarily removed while the adapter is loaded and
            re-enabled afterwards (also when loading fails).

    Raises:
        ValueError: If `adapter_name` is already used by `transformer`, or if the checkpoint is DoRA-enabled
            and the installed `peft` is older than 0.9.0.
    """
    from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict

    # Keep only the transformer entries. Filtering inline (instead of testing membership in a list of the
    # matching keys) keeps this linear in the number of keys.
    state_dict = {
        k.replace(f"{cls.transformer_name}.", ""): v
        for k, v in state_dict.items()
        if k.startswith(cls.transformer_name)
    }

    if len(state_dict.keys()) == 0:
        # Nothing to load for the transformer.
        return

    # check with first key if is not in peft format
    first_key = next(iter(state_dict.keys()))
    if "lora_A" not in first_key:
        state_dict = convert_unet_state_dict_to_peft(state_dict)

    if adapter_name in getattr(transformer, "peft_config", {}):
        raise ValueError(
            f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
        )

    # The rank of each LoRA layer is read from the second dimension of its `lora_B` tensor.
    rank = {key: val.shape[1] for key, val in state_dict.items() if "lora_B" in key}

    if network_alphas is not None and len(network_alphas) >= 1:
        prefix = cls.transformer_name
        # Keep only the transformer's alphas (first dotted component == prefix) and strip the prefix.
        network_alphas = {
            k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k.split(".")[0] == prefix
        }

    lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
    if "use_dora" in lora_config_kwargs:
        if lora_config_kwargs["use_dora"]:
            if is_peft_version("<", "0.9.0"):
                raise ValueError(
                    "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
                )
        elif is_peft_version("<", "0.9.0"):
            # Only strip the flag for peft releases that predate DoRA. A True flag must be kept so that
            # DoRA checkpoints are actually loaded as DoRA on newer peft versions.
            lora_config_kwargs.pop("use_dora")
    lora_config = LoraConfig(**lora_config_kwargs)

    # adapter_name
    if adapter_name is None:
        adapter_name = get_adapter_name(transformer)

    # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
    # otherwise loading LoRA weights will lead to an error
    is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)

    try:
        inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
        incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
    finally:
        # Offload back, even if loading raised, so the pipeline is not left without its hooks.
        if is_model_cpu_offload:
            _pipeline.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            _pipeline.enable_sequential_cpu_offload()
        # Unsafe code />

    if incompatible_keys is not None:
        # check only for unexpected keys
        unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
        if unexpected_keys:
            logger.warning(
                f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                f" {unexpected_keys}. "
            )
|
1747
|
+
|
1748
|
+
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
# NOTE(review): this body diverges from the copy source (single-pass prefix filtering that strips only the
# leading prefix, and `try`/`finally` around offloading); apply the same change there, otherwise
# `make fix-copies` will revert it.
def load_lora_into_text_encoder(
    cls,
    state_dict,
    network_alphas,
    text_encoder,
    prefix=None,
    lora_scale=1.0,
    adapter_name=None,
    _pipeline=None,
):
    """
    This will load the LoRA layers specified in `state_dict` into `text_encoder`.

    Parameters:
        state_dict (`dict`):
            A standard state dict containing the lora layer parameters. Only entries whose first dotted key
            component is exactly `prefix` are used, and that leading component is removed before loading.
        network_alphas (`Dict[str, float]`):
            See `LoRALinearLayer` for more details. Filtered and stripped with the same `prefix` rule as
            `state_dict`.
        text_encoder (`CLIPTextModel`):
            The text encoder model to load the LoRA layers into.
        prefix (`str`, *optional*):
            Expected prefix of the `text_encoder` in the `state_dict`. Defaults to `cls.text_encoder_name`.
        lora_scale (`float`):
            How much to scale the output of the lora linear layer before it is added with the output of the
            regular layer.
        adapter_name (`str`, *optional*):
            Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
            `default_{i}` where i is the total number of adapters being loaded.
        _pipeline (*optional*):
            Pipeline handed to `cls._optionally_disable_offloading`. If that reports model or sequential CPU
            offloading, the matching offloading is re-enabled on it after loading (also when loading fails).

    Raises:
        ValueError: If `USE_PEFT_BACKEND` is false, or if the LoRA requests DoRA (`use_dora`) while the
            installed `peft` is older than 0.9.0.
    """
    if not USE_PEFT_BACKEND:
        raise ValueError("PEFT backend is required for this method.")

    from peft import LoraConfig

    def _strip_prefix(mapping, expected_prefix):
        # Keep only entries whose first dotted component equals `expected_prefix` and drop just that leading
        # component. One pass over the mapping (the previous `k in <list of keys>` filter was quadratic), and
        # unlike `str.replace` it never removes a later "<prefix>." occurring inside the key.
        selected = {}
        for key, value in mapping.items():
            head, sep, tail = key.partition(".")
            if sep and head == expected_prefix:
                selected[tail] = value
        return selected

    # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
    # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
    # their prefixes.
    keys = list(state_dict.keys())
    prefix = cls.text_encoder_name if prefix is None else prefix

    # Safe prefix to check with.
    if any(cls.text_encoder_name in key for key in keys):
        # Load the layers corresponding to text encoder and make necessary adjustments.
        text_encoder_lora_state_dict = _strip_prefix(state_dict, prefix)

        if len(text_encoder_lora_state_dict) > 0:
            logger.info(f"Loading {prefix}.")
            rank = {}
            text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)

            # convert state dict
            text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)

            # Record the rank (dim 1 of each `lora_B` weight) for every attention projection that has LoRA weights.
            for name, _ in text_encoder_attn_modules(text_encoder):
                for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
                    rank_key = f"{name}.{module}.lora_B.weight"
                    if rank_key not in text_encoder_lora_state_dict:
                        continue
                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

            # Same for the MLP projections.
            for name, _ in text_encoder_mlp_modules(text_encoder):
                for module in ("fc1", "fc2"):
                    rank_key = f"{name}.{module}.lora_B.weight"
                    if rank_key not in text_encoder_lora_state_dict:
                        continue
                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

            if network_alphas is not None:
                network_alphas = _strip_prefix(network_alphas, prefix)

            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
            if "use_dora" in lora_config_kwargs and is_peft_version("<", "0.9.0"):
                if lora_config_kwargs["use_dora"]:
                    raise ValueError(
                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
                    )
                # Only dropped for old `peft` (presumably its `LoraConfig` does not accept `use_dora`).
                lora_config_kwargs.pop("use_dora")
            lora_config = LoraConfig(**lora_config_kwargs)

            # adapter_name
            if adapter_name is None:
                adapter_name = get_adapter_name(text_encoder)

            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)

            try:
                # inject LoRA layers and load the state dict
                # in transformers we automatically check whether the adapter name is already in use or not
                text_encoder.load_adapter(
                    adapter_name=adapter_name,
                    adapter_state_dict=text_encoder_lora_state_dict,
                    peft_config=lora_config,
                )

                # scale LoRA layers with `lora_scale`
                scale_lora_layers(text_encoder, weight=lora_scale)

                text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
            finally:
                # Offload back, even if loading raised, so a failed load does not leave the pipeline with its
                # offloading hooks removed.
                if is_model_cpu_offload:
                    _pipeline.enable_model_cpu_offload()
                elif is_sequential_cpu_offload:
                    _pipeline.enable_sequential_cpu_offload()
            # Unsafe code />
|
1866
|
+
|
1867
|
+
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer
# NOTE(review): the docstring/annotations here were corrected for the transformer variant; keep the copy
# source in sync or `make fix-copies` will revert them.
def save_lora_weights(
    cls,
    save_directory: Union[str, os.PathLike],
    transformer_lora_layers: Optional[Dict[str, Union[torch.nn.Module, torch.Tensor]]] = None,
    text_encoder_lora_layers: Optional[Dict[str, Union[torch.nn.Module, torch.Tensor]]] = None,
    is_main_process: bool = True,
    weight_name: Optional[str] = None,
    save_function: Optional[Callable] = None,
    safe_serialization: bool = True,
):
    r"""
    Save the LoRA parameters corresponding to the transformer and text encoder.

    Each provided component is packed with `cls.pack_weights` under `cls.transformer_name` /
    `cls.text_encoder_name`, the results are merged into one state dict, and that state dict is handed to
    `cls.write_lora_layers`.

    Arguments:
        save_directory (`str` or `os.PathLike`):
            Directory to save LoRA parameters to. Will be created if it doesn't exist.
        transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`, *optional*):
            State dict of the LoRA layers corresponding to the `transformer`.
        text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`, *optional*):
            State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
            encoder LoRA state dict because it comes from 🤗 Transformers.
        is_main_process (`bool`, *optional*, defaults to `True`):
            Whether the process calling this is the main process or not. Useful during distributed training and you
            need to call this function on all processes. In this case, set `is_main_process=True` only on the main
            process to avoid race conditions.
        weight_name (`str`, *optional*):
            File name for the saved weights; forwarded to `cls.write_lora_layers`.
        save_function (`Callable`, *optional*):
            The function to use to save the state dictionary. Useful during distributed training when you need to
            replace `torch.save` with another method. Can be configured with the environment variable
            `DIFFUSERS_SAVE_MODE`.
        safe_serialization (`bool`, *optional*, defaults to `True`):
            Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.

    Raises:
        ValueError: If neither `transformer_lora_layers` nor `text_encoder_lora_layers` is given (or both are empty).
    """
    state_dict = {}

    if not (transformer_lora_layers or text_encoder_lora_layers):
        raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")

    if transformer_lora_layers:
        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

    if text_encoder_lora_layers:
        state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))

    # Save the model
    cls.write_lora_layers(
        state_dict=state_dict,
        save_directory=save_directory,
        is_main_process=is_main_process,
        weight_name=weight_name,
        save_function=save_function,
        safe_serialization=safe_serialization,
    )
|
1921
|
+
|
1922
|
+
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
|
1923
|
+
def fuse_lora(
    self,
    components: List[str] = ["transformer", "text_encoder"],
    lora_scale: float = 1.0,
    safe_fusing: bool = False,
    adapter_names: Optional[List[str]] = None,
    **kwargs,
):
    r"""
    Fuse the loaded LoRA parameters into the original weights of the selected components.

    <Tip warning={true}>

    This is an experimental API.

    </Tip>

    Args:
        components (`List[str]`, defaults to `["transformer", "text_encoder"]`):
            LoRA-injectable components to fuse the LoRAs into.
        lora_scale (`float`, defaults to 1.0):
            Controls how much to influence the outputs with the LoRA parameters.
        safe_fusing (`bool`, defaults to `False`):
            Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
        adapter_names (`List[str]`, *optional*):
            Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
        **kwargs:
            Accepted by the signature but not passed on to the parent implementation.
    """
    fuse_options = {
        "components": components,
        "lora_scale": lora_scale,
        "safe_fusing": safe_fusing,
        "adapter_names": adapter_names,
    }
    # The parent mixin performs the actual fusing; `kwargs` is not forwarded.
    super().fuse_lora(**fuse_options)
|
1965
|
+
|
1966
|
+
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
    r"""
    Undo the effect of
    [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora)
    on the selected components.

    <Tip warning={true}>

    This is an experimental API.

    </Tip>

    Args:
        components (`List[str]`, defaults to `["transformer", "text_encoder"]`):
            LoRA-injectable components to unfuse the LoRA from.
        **kwargs:
            Accepted by the signature but not passed on to the parent implementation.
    """
    parent = super()
    parent.unfuse_lora(components=components)
|
1981
|
+
|
1982
|
+
|
1983
|
+
# We subclass `StableDiffusionLoraLoaderMixin` here because Amused initially relied on it
# for its LoRA support.
|
1985
|
+
class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
    """
    LoRA loading/saving for Amused, whose LoRA-loadable components are `transformer` and `text_encoder`.

    `save_lora_weights` packs each component's layers under `transformer_name` / `text_encoder_name`, and the
    loaders below select state-dict entries by those same prefixes.
    """

    _lora_loadable_modules = ["transformer", "text_encoder"]
    transformer_name = TRANSFORMER_NAME
    text_encoder_name = TEXT_ENCODER_NAME

    @classmethod
    def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
        """
        This will load the LoRA layers specified in `state_dict` into `transformer`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. Only keys starting with
                `cls.transformer_name` are used; a leading `"{cls.transformer_name}."` is removed from them.
            network_alphas (`Dict[str, float]`):
                See `LoRALinearLayer` for more details. Filtered and stripped the same way as `state_dict`.
            transformer (`torch.nn.Module`):
                The transformer model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            _pipeline (*optional*):
                Pipeline handed to `cls._optionally_disable_offloading`. If that reports model or sequential CPU
                offloading, the matching offloading is re-enabled on it after loading (also when loading fails).

        Raises:
            ValueError: If `USE_PEFT_BACKEND` is false, if `adapter_name` is already present in
                `transformer.peft_config`, or if the LoRA requests DoRA while the installed `peft` is older than
                0.9.0.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict

        component_name = cls.transformer_name
        leading = f"{component_name}."

        def _select_transformer_entries(mapping):
            # Same selection rule as before (keys starting with the transformer name), done in a single pass
            # instead of a membership test against a list of keys. Only a *leading* "<name>." is stripped;
            # `str.replace` would also remove later occurrences inside the key.
            return {
                (key[len(leading):] if key.startswith(leading) else key): value
                for key, value in mapping.items()
                if key.startswith(component_name)
            }

        state_dict = _select_transformer_entries(state_dict)

        if network_alphas is not None:
            network_alphas = _select_transformer_entries(network_alphas)

        if len(state_dict.keys()) > 0:
            if adapter_name in getattr(transformer, "peft_config", {}):
                raise ValueError(
                    f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
                )

            # LoRA rank per `lora_B` weight: its dim 1.
            rank = {key: val.shape[1] for key, val in state_dict.items() if "lora_B" in key}

            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
            # Previously `use_dora` was popped whenever the "DoRA on old peft" check failed, which silently
            # discarded `use_dora=True` on peft>=0.9.0. Now it is only dropped for old peft, matching
            # `load_lora_into_text_encoder`.
            if "use_dora" in lora_config_kwargs and is_peft_version("<", "0.9.0"):
                if lora_config_kwargs["use_dora"]:
                    raise ValueError(
                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
                    )
                lora_config_kwargs.pop("use_dora")
            lora_config = LoraConfig(**lora_config_kwargs)

            # adapter_name
            if adapter_name is None:
                adapter_name = get_adapter_name(transformer)

            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
            # otherwise loading LoRA weights will lead to an error
            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)

            try:
                inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
                incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)

                if incompatible_keys is not None:
                    # check only for unexpected keys
                    unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
                    if unexpected_keys:
                        logger.warning(
                            f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                            f" {unexpected_keys}. "
                        )
            finally:
                # Offload back, even if loading raised.
                if is_model_cpu_offload:
                    _pipeline.enable_model_cpu_offload()
                elif is_sequential_cpu_offload:
                    _pipeline.enable_sequential_cpu_offload()
            # Unsafe code />

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
    # NOTE(review): this body diverges from the copy source (single-pass prefix filtering that strips only the
    # leading prefix, and `try`/`finally` around offloading); apply the same change there, otherwise
    # `make fix-copies` will revert it.
    def load_lora_into_text_encoder(
        cls,
        state_dict,
        network_alphas,
        text_encoder,
        prefix=None,
        lora_scale=1.0,
        adapter_name=None,
        _pipeline=None,
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `text_encoder`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. Only entries whose first dotted key
                component is exactly `prefix` are used, and that leading component is removed before loading.
            network_alphas (`Dict[str, float]`):
                See `LoRALinearLayer` for more details. Filtered and stripped with the same `prefix` rule as
                `state_dict`.
            text_encoder (`CLIPTextModel`):
                The text encoder model to load the LoRA layers into.
            prefix (`str`, *optional*):
                Expected prefix of the `text_encoder` in the `state_dict`. Defaults to `cls.text_encoder_name`.
            lora_scale (`float`):
                How much to scale the output of the lora linear layer before it is added with the output of the
                regular layer.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            _pipeline (*optional*):
                Pipeline handed to `cls._optionally_disable_offloading`. If that reports model or sequential CPU
                offloading, the matching offloading is re-enabled on it after loading (also when loading fails).

        Raises:
            ValueError: If `USE_PEFT_BACKEND` is false, or if the LoRA requests DoRA (`use_dora`) while the
                installed `peft` is older than 0.9.0.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        from peft import LoraConfig

        def _strip_prefix(mapping, expected_prefix):
            # Keep only entries whose first dotted component equals `expected_prefix` and drop just that leading
            # component; single pass, and never touches a later "<prefix>." inside the key.
            selected = {}
            for key, value in mapping.items():
                head, sep, tail = key.partition(".")
                if sep and head == expected_prefix:
                    selected[tail] = value
            return selected

        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
        # their prefixes.
        keys = list(state_dict.keys())
        prefix = cls.text_encoder_name if prefix is None else prefix

        # Safe prefix to check with.
        if any(cls.text_encoder_name in key for key in keys):
            # Load the layers corresponding to text encoder and make necessary adjustments.
            text_encoder_lora_state_dict = _strip_prefix(state_dict, prefix)

            if len(text_encoder_lora_state_dict) > 0:
                logger.info(f"Loading {prefix}.")
                rank = {}
                text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)

                # convert state dict
                text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)

                # Record the rank (dim 1 of each `lora_B` weight) for every attention projection with LoRA weights.
                for name, _ in text_encoder_attn_modules(text_encoder):
                    for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
                        rank_key = f"{name}.{module}.lora_B.weight"
                        if rank_key not in text_encoder_lora_state_dict:
                            continue
                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

                # Same for the MLP projections.
                for name, _ in text_encoder_mlp_modules(text_encoder):
                    for module in ("fc1", "fc2"):
                        rank_key = f"{name}.{module}.lora_B.weight"
                        if rank_key not in text_encoder_lora_state_dict:
                            continue
                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

                if network_alphas is not None:
                    network_alphas = _strip_prefix(network_alphas, prefix)

                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
                if "use_dora" in lora_config_kwargs and is_peft_version("<", "0.9.0"):
                    if lora_config_kwargs["use_dora"]:
                        raise ValueError(
                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
                        )
                    # Only dropped for old `peft` (presumably its `LoraConfig` does not accept `use_dora`).
                    lora_config_kwargs.pop("use_dora")
                lora_config = LoraConfig(**lora_config_kwargs)

                # adapter_name
                if adapter_name is None:
                    adapter_name = get_adapter_name(text_encoder)

                is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)

                try:
                    # inject LoRA layers and load the state dict
                    # in transformers we automatically check whether the adapter name is already in use or not
                    text_encoder.load_adapter(
                        adapter_name=adapter_name,
                        adapter_state_dict=text_encoder_lora_state_dict,
                        peft_config=lora_config,
                    )

                    # scale LoRA layers with `lora_scale`
                    scale_lora_layers(text_encoder, weight=lora_scale)

                    text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
                finally:
                    # Offload back, even if loading raised.
                    if is_model_cpu_offload:
                        _pipeline.enable_model_cpu_offload()
                    elif is_sequential_cpu_offload:
                        _pipeline.enable_sequential_cpu_offload()
                # Unsafe code />

    @classmethod
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        text_encoder_lora_layers: Optional[Dict[str, torch.nn.Module]] = None,
        transformer_lora_layers: Optional[Dict[str, torch.nn.Module]] = None,
        is_main_process: bool = True,
        weight_name: Optional[str] = None,
        save_function: Optional[Callable] = None,
        safe_serialization: bool = True,
    ):
        r"""
        Save the LoRA parameters corresponding to the transformer and text encoder.

        Note the parameter order: `text_encoder_lora_layers` comes before `transformer_lora_layers`.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`, *optional*):
                State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                encoder LoRA state dict because it comes from 🤗 Transformers.
            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`, *optional*):
                State dict of the LoRA layers corresponding to the `transformer`.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training and you
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                process to avoid race conditions.
            weight_name (`str`, *optional*):
                File name for the saved weights; forwarded to `cls.write_lora_layers`.
            save_function (`Callable`, *optional*):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.

        Raises:
            ValueError: If neither `transformer_lora_layers` nor `text_encoder_lora_layers` is given (or both are
                empty).
        """
        state_dict = {}

        if not (transformer_lora_layers or text_encoder_lora_layers):
            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

        if transformer_lora_layers:
            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

        if text_encoder_lora_layers:
            state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))

        # Save the model
        cls.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )
|
2246
|
+
|
2247
|
+
|
2248
|
+
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
    """
    Deprecated alias of `StableDiffusionLoraLoaderMixin`.

    On instantiation, it reports the deprecation through `deprecate` (passing version "1.0.0") and then
    defers to the parent initializer with the same arguments.
    """

    def __init__(self, *args, **kwargs):
        message = (
            "LoraLoaderMixin is deprecated and this will be removed in a future version. "
            "Please use `StableDiffusionLoraLoaderMixin`, instead."
        )
        deprecate("LoraLoaderMixin", "1.0.0", message)
        super().__init__(*args, **kwargs)
|