diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +97 -4
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +13 -1
- diffusers/image_processor.py +282 -71
- diffusers/loaders/__init__.py +24 -3
- diffusers/loaders/ip_adapter.py +543 -16
- diffusers/loaders/lora_base.py +138 -125
- diffusers/loaders/lora_conversion_utils.py +647 -0
- diffusers/loaders/lora_pipeline.py +2216 -230
- diffusers/loaders/peft.py +380 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +597 -10
- diffusers/loaders/textual_inversion.py +5 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +56 -12
- diffusers/models/__init__.py +49 -12
- diffusers/models/activations.py +22 -9
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +98 -13
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2160 -346
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +73 -12
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +26 -376
- diffusers/models/controlnet_sparsectrl.py +46 -719
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +996 -92
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +264 -14
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +334 -51
- diffusers/models/normalization.py +157 -13
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +10 -2
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +189 -51
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +112 -18
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +9 -9
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +46 -68
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +14 -6
- diffusers/pipelines/__init__.py +69 -6
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
- diffusers/pipelines/auto_pipeline.py +88 -10
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
- diffusers/pipelines/flux/__init__.py +23 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +256 -48
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +13 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +250 -31
- diffusers/pipelines/pipeline_utils.py +158 -186
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +6 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
- diffusers/schedulers/scheduling_deis_multistep.py +102 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +102 -6
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
- diffusers/training_utils.py +63 -19
- diffusers/utils/__init__.py +7 -1
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +240 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +44 -40
- diffusers/utils/import_utils.py +98 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +6 -3
- diffusers/utils/testing_utils.py +115 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
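
The per-file source diff below is for `diffusers/loaders/lora_pipeline.py`, the largest single rewrite in this release (+2216 -230). The recurring theme is a new `low_cpu_mem_usage` option threaded through every LoRA loader mixin, plus DoRA-key filtering and converters for Kohya-, XLabs-, and BFL-style Flux LoRA checkpoints. As a rough, hedged illustration of how the new flag is exposed to callers (the checkpoint and LoRA repository IDs here are placeholders, not taken from this diff):

```python
# Sketch only: model and LoRA identifiers are illustrative placeholders.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("some/sd-checkpoint", torch_dtype=torch.float16)

# In 0.32.0, load_lora_weights() accepts low_cpu_mem_usage; it defaults to True when the
# version gate shown in the first hunk below is satisfied (torch >= 1.9.0, peft >= 0.13.1,
# transformers > 4.45.2), and can also be passed explicitly to skip random weight init.
pipe.load_lora_weights(
    "some/lora-repo",
    weight_name="pytorch_lora_weights.safetensors",
    low_cpu_mem_usage=True,
)
```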
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import os
 from typing import Callable, Dict, List, Optional, Union

@@ -21,17 +22,36 @@ from ..utils import (
     USE_PEFT_BACKEND,
     convert_state_dict_to_diffusers,
     convert_state_dict_to_peft,
-    convert_unet_state_dict_to_peft,
     deprecate,
     get_adapter_name,
     get_peft_kwargs,
+    is_peft_available,
     is_peft_version,
+    is_torch_version,
     is_transformers_available,
+    is_transformers_version,
     logging,
     scale_lora_layers,
 )
-from .lora_base import LoraBaseMixin
-from .lora_conversion_utils import
+from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict  # noqa
+from .lora_conversion_utils import (
+    _convert_bfl_flux_control_lora_to_diffusers,
+    _convert_kohya_flux_lora_to_diffusers,
+    _convert_non_diffusers_lora_to_diffusers,
+    _convert_xlabs_flux_lora_to_diffusers,
+    _maybe_map_sgm_blocks_to_diffusers,
+)
+
+
+_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
+if is_torch_version(">=", "1.9.0"):
+    if (
+        is_peft_available()
+        and is_peft_version(">=", "0.13.1")
+        and is_transformers_available()
+        and is_transformers_version(">", "4.45.2")
+    ):
+        _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True


 if is_transformers_available():
@@ -43,8 +63,7 @@ TEXT_ENCODER_NAME = "text_encoder"
 UNET_NAME = "unet"
 TRANSFORMER_NAME = "transformer"

-
-LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}


 class StableDiffusionLoraLoaderMixin(LoraBaseMixin):

@@ -78,15 +97,24 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         Parameters:
             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
-            kwargs (`dict`, *optional*):
-                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")

+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
         # if a dict is passed, copy it instead of modifying it inplace
         if isinstance(pretrained_model_name_or_path_or_dict, dict):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
@@ -94,7 +122,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

-        is_correct_format = all("lora" in key
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

@@ -104,6 +132,7 @@
             unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
             adapter_name=adapter_name,
             _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )
         self.load_lora_into_text_encoder(
             state_dict,

@@ -114,6 +143,7 @@
             lora_scale=self.lora_scale,
             adapter_name=adapter_name,
             _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )

     @classmethod

@@ -192,7 +222,7 @@
             "framework": "pytorch",
         }

-        state_dict =
+        state_dict = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -206,6 +236,11 @@
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

         network_alphas = None
         # TODO: replace it with a method from `state_dict_utils`

@@ -227,7 +262,9 @@
         return state_dict, network_alphas

     @classmethod
-    def load_lora_into_unet(
+    def load_lora_into_unet(
+        cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    ):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.

@@ -245,10 +282,18 @@
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")

+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
         # their prefixes.

@@ -257,8 +302,13 @@
         if not only_text_encoder:
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
-            unet.
-                state_dict,
+            unet.load_lora_adapter(
+                state_dict,
+                prefix=cls.unet_name,
+                network_alphas=network_alphas,
+                adapter_name=adapter_name,
+                _pipeline=_pipeline,
+                low_cpu_mem_usage=low_cpu_mem_usage,
             )

     @classmethod
@@ -271,6 +321,7 @@
         lora_scale=1.0,
         adapter_name=None,
         _pipeline=None,
+        low_cpu_mem_usage=False,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`

@@ -280,7 +331,9 @@
                 A standard state dict containing the lora layer parameters. The key should be prefixed with an
                 additional `text_encoder` to distinguish between unet lora layers.
             network_alphas (`Dict[str, float]`):
-
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
             text_encoder (`CLIPTextModel`):
                 The text encoder model to load the LoRA layers into.
             prefix (`str`):

@@ -291,10 +344,27 @@
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")

+        peft_kwargs = {}
+        if low_cpu_mem_usage:
+            if not is_peft_version(">=", "0.13.1"):
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+                )
+            if not is_transformers_version(">", "4.45.2"):
+                # Note from sayakpaul: It's not in `transformers` stable yet.
+                # https://github.com/huggingface/transformers/pull/33725/
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+                )
+            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
         from peft import LoraConfig

         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -342,6 +412,7 @@
             }

             lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"]:
                     if is_peft_version("<", "0.9.0"):

@@ -351,6 +422,17 @@
                 else:
                     if is_peft_version("<", "0.9.0"):
                         lora_config_kwargs.pop("use_dora")
+
+            if "lora_bias" in lora_config_kwargs:
+                if lora_config_kwargs["lora_bias"]:
+                    if is_peft_version("<=", "0.13.2"):
+                        raise ValueError(
+                            "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<=", "0.13.2"):
+                        lora_config_kwargs.pop("lora_bias")
+
             lora_config = LoraConfig(**lora_config_kwargs)

             # adapter_name

@@ -365,6 +447,7 @@
                 adapter_name=adapter_name,
                 adapter_state_dict=text_encoder_lora_state_dict,
                 peft_config=lora_config,
+                **peft_kwargs,
             )

             # scale LoRA layers with `lora_scale`
@@ -535,12 +618,21 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
             kwargs (`dict`, *optional*):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")

+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
         # We could have accessed the unet config from `lora_state_dict()` too. We pass
         # it here explicitly to be able to tell that it's coming from an SDXL
         # pipeline.

@@ -555,12 +647,18 @@
             unet_config=self.unet.config,
             **kwargs,
         )
-
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

         self.load_lora_into_unet(
-            state_dict,
+            state_dict,
+            network_alphas=network_alphas,
+            unet=self.unet,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
         if len(text_encoder_state_dict) > 0:
@@ -572,6 +670,7 @@
                 lora_scale=self.lora_scale,
                 adapter_name=adapter_name,
                 _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
             )

         text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}

@@ -584,6 +683,7 @@
                 lora_scale=self.lora_scale,
                 adapter_name=adapter_name,
                 _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
             )

     @classmethod

@@ -663,7 +763,7 @@
             "framework": "pytorch",
         }

-        state_dict =
+        state_dict = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,

@@ -677,6 +777,11 @@
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

         network_alphas = None
         # TODO: replace it with a method from `state_dict_utils`
@@ -699,7 +804,9 @@

     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
-    def load_lora_into_unet(
+    def load_lora_into_unet(
+        cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    ):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.

@@ -717,10 +824,18 @@
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")

+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
         # their prefixes.

@@ -729,8 +844,13 @@
         if not only_text_encoder:
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
-            unet.
-                state_dict,
+            unet.load_lora_adapter(
+                state_dict,
+                prefix=cls.unet_name,
+                network_alphas=network_alphas,
+                adapter_name=adapter_name,
+                _pipeline=_pipeline,
+                low_cpu_mem_usage=low_cpu_mem_usage,
             )

     @classmethod
@@ -744,6 +864,7 @@
         lora_scale=1.0,
         adapter_name=None,
         _pipeline=None,
+        low_cpu_mem_usage=False,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`

@@ -753,7 +874,9 @@
                 A standard state dict containing the lora layer parameters. The key should be prefixed with an
                 additional `text_encoder` to distinguish between unet lora layers.
             network_alphas (`Dict[str, float]`):
-
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
             text_encoder (`CLIPTextModel`):
                 The text encoder model to load the LoRA layers into.
             prefix (`str`):

@@ -764,10 +887,27 @@
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")

+        peft_kwargs = {}
+        if low_cpu_mem_usage:
+            if not is_peft_version(">=", "0.13.1"):
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+                )
+            if not is_transformers_version(">", "4.45.2"):
+                # Note from sayakpaul: It's not in `transformers` stable yet.
+                # https://github.com/huggingface/transformers/pull/33725/
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+                )
+            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
         from peft import LoraConfig

         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -815,6 +955,7 @@
             }

             lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"]:
                     if is_peft_version("<", "0.9.0"):

@@ -824,6 +965,17 @@
                 else:
                     if is_peft_version("<", "0.9.0"):
                         lora_config_kwargs.pop("use_dora")
+
+            if "lora_bias" in lora_config_kwargs:
+                if lora_config_kwargs["lora_bias"]:
+                    if is_peft_version("<=", "0.13.2"):
+                        raise ValueError(
+                            "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<=", "0.13.2"):
+                        lora_config_kwargs.pop("lora_bias")
+
             lora_config = LoraConfig(**lora_config_kwargs)

             # adapter_name

@@ -838,6 +990,7 @@
                 adapter_name=adapter_name,
                 adapter_state_dict=text_encoder_lora_state_dict,
                 peft_config=lora_config,
+                **peft_kwargs,
             )

             # scale LoRA layers with `lora_scale`
@@ -1065,7 +1218,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             "framework": "pytorch",
         }

-        state_dict =
+        state_dict = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,

@@ -1080,6 +1233,12 @@
             allow_pickle=allow_pickle,
         )

+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
         return state_dict

     def load_lora_weights(

@@ -1100,15 +1259,24 @@
         Parameters:
             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
-            kwargs (`dict`, *optional*):
-                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")

+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
         # if a dict is passed, copy it instead of modifying it inplace
         if isinstance(pretrained_model_name_or_path_or_dict, dict):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
@@ -1116,16 +1284,21 @@
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

-        is_correct_format = all("lora" in key
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

-
-
-
-
-
-
+        transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
+        if len(transformer_state_dict) > 0:
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=getattr(self, self.transformer_name)
+                if not hasattr(self, "transformer")
+                else self.transformer,
+                adapter_name=adapter_name,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+            )

         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
         if len(text_encoder_state_dict) > 0:

@@ -1137,6 +1310,7 @@
                 lora_scale=self.lora_scale,
                 adapter_name=adapter_name,
                 _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
             )

         text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}

@@ -1149,10 +1323,13 @@
                 lora_scale=self.lora_scale,
                 adapter_name=adapter_name,
                 _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
             )

     @classmethod
-    def load_lora_into_transformer(
+    def load_lora_into_transformer(
+        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    ):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.

@@ -1166,68 +1343,24 @@
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
-
-
-
-
-        transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
-        state_dict = {
-            k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
-        }
-
-        if len(state_dict.keys()) > 0:
-            # check with first key if is not in peft format
-            first_key = next(iter(state_dict.keys()))
-            if "lora_A" not in first_key:
-                state_dict = convert_unet_state_dict_to_peft(state_dict)
-
-            if adapter_name in getattr(transformer, "peft_config", {}):
-                raise ValueError(
-                    f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
-                )
-
-            rank = {}
-            for key, val in state_dict.items():
-                if "lora_B" in key:
-                    rank[key] = val.shape[1]
-
-            lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
-                    raise ValueError(
-                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                    )
-                else:
-                    lora_config_kwargs.pop("use_dora")
-            lora_config = LoraConfig(**lora_config_kwargs)
-
-            # adapter_name
-            if adapter_name is None:
-                adapter_name = get_adapter_name(transformer)
-
-            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-            # otherwise loading LoRA weights will lead to an error
-            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
-            inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
-            incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
-
-            if incompatible_keys is not None:
-                # check only for unexpected keys
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    logger.warning(
-                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                        f" {unexpected_keys}. "
-                    )
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )

-
-
-
-
-
-
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )

     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -1240,6 +1373,7 @@
         lora_scale=1.0,
         adapter_name=None,
         _pipeline=None,
+        low_cpu_mem_usage=False,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`

@@ -1249,7 +1383,9 @@
                 A standard state dict containing the lora layer parameters. The key should be prefixed with an
                 additional `text_encoder` to distinguish between unet lora layers.
             network_alphas (`Dict[str, float]`):
-
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
             text_encoder (`CLIPTextModel`):
                 The text encoder model to load the LoRA layers into.
             prefix (`str`):

@@ -1260,10 +1396,27 @@
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")

+        peft_kwargs = {}
+        if low_cpu_mem_usage:
+            if not is_peft_version(">=", "0.13.1"):
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+                )
+            if not is_transformers_version(">", "4.45.2"):
+                # Note from sayakpaul: It's not in `transformers` stable yet.
+                # https://github.com/huggingface/transformers/pull/33725/
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+                )
+            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
         from peft import LoraConfig

         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -1311,6 +1464,7 @@
             }

             lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
             if "use_dora" in lora_config_kwargs:
                 if lora_config_kwargs["use_dora"]:
                     if is_peft_version("<", "0.9.0"):

@@ -1320,6 +1474,17 @@
                 else:
                     if is_peft_version("<", "0.9.0"):
                         lora_config_kwargs.pop("use_dora")
+
+            if "lora_bias" in lora_config_kwargs:
+                if lora_config_kwargs["lora_bias"]:
+                    if is_peft_version("<=", "0.13.2"):
+                        raise ValueError(
+                            "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<=", "0.13.2"):
+                        lora_config_kwargs.pop("lora_bias")
+
             lora_config = LoraConfig(**lora_config_kwargs)

             # adapter_name

@@ -1334,6 +1499,7 @@
                 adapter_name=adapter_name,
                 adapter_state_dict=text_encoder_lora_state_dict,
                 peft_config=lora_config,
+                **peft_kwargs,
             )

             # scale LoRA layers with `lora_scale`
@@ -1486,6 +1652,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
     _lora_loadable_modules = ["transformer", "text_encoder"]
     transformer_name = TRANSFORMER_NAME
     text_encoder_name = TEXT_ENCODER_NAME
+    _control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]

     @classmethod
     @validate_hf_hub_args

@@ -1562,7 +1729,7 @@
             "framework": "pytorch",
         }

-        state_dict =
+        state_dict = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -1576,6 +1743,29 @@
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+        # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions.
+        is_kohya = any(".lora_down.weight" in k for k in state_dict)
+        if is_kohya:
+            state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
+            # Kohya already takes care of scaling the LoRA parameters with alpha.
+            return (state_dict, None) if return_alphas else state_dict
+
+        is_xlabs = any("processor" in k for k in state_dict)
+        if is_xlabs:
+            state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
+            # xlabs doesn't use `alpha`.
+            return (state_dict, None) if return_alphas else state_dict
+
+        is_bfl_control = any("query_norm.scale" in k for k in state_dict)
+        if is_bfl_control:
+            state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
+            return (state_dict, None) if return_alphas else state_dict

         # For state dicts like
         # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
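
The hunk above makes `FluxLoraLoaderMixin.lora_state_dict` detect third-party Flux LoRA layouts from characteristic key patterns before converting them. A standalone sketch of that dispatch logic follows; `detect_flux_lora_format` is a hypothetical helper used only for illustration, not a diffusers API, though the patterns and converter names come from the diff itself:

```python
# Illustrative sketch of the key-pattern dispatch added above; the function name
# and the example key are hypothetical.
def detect_flux_lora_format(state_dict: dict) -> str:
    if any(".lora_down.weight" in k for k in state_dict):
        return "kohya"        # handled by _convert_kohya_flux_lora_to_diffusers
    if any("processor" in k for k in state_dict):
        return "xlabs"        # handled by _convert_xlabs_flux_lora_to_diffusers
    if any("query_norm.scale" in k for k in state_dict):
        return "bfl_control"  # handled by _convert_bfl_flux_control_lora_to_diffusers
    return "diffusers"        # already in the diffusers/peft key layout


print(detect_flux_lora_format({"lora_unet_x.lora_down.weight": None}))  # -> kohya
```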
@@ -1621,10 +1811,19 @@
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                `Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")

+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
         # if a dict is passed, copy it instead of modifying it inplace
         if isinstance(pretrained_model_name_or_path_or_dict, dict):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
@@ -1634,18 +1833,57 @@
             pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
         )

-
-
+        has_lora_keys = any("lora" in key for key in state_dict.keys())
+
+        # Flux Control LoRAs also have norm keys
+        has_norm_keys = any(
+            norm_key in key for key in state_dict.keys() for norm_key in self._control_lora_supported_norm_keys
+        )
+
+        if not (has_lora_keys or has_norm_keys):
             raise ValueError("Invalid LoRA checkpoint.")

-
-            state_dict
-
-
-
-
+        transformer_lora_state_dict = {
+            k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k
+        }
+        transformer_norm_state_dict = {
+            k: state_dict.pop(k)
+            for k in list(state_dict.keys())
+            if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
+        }
+
+        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+        has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
+            transformer, transformer_lora_state_dict, transformer_norm_state_dict
+        )
+
+        if has_param_with_expanded_shape:
+            logger.info(
+                "The LoRA weights contain parameters that have different shapes that expected by the transformer. "
+                "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
+                "To get a comprehensive list of parameter names that were modified, enable debug logging."
+            )
+        transformer_lora_state_dict = self._maybe_expand_lora_state_dict(
+            transformer=transformer, lora_state_dict=transformer_lora_state_dict
         )

+        if len(transformer_lora_state_dict) > 0:
+            self.load_lora_into_transformer(
+                transformer_lora_state_dict,
+                network_alphas=network_alphas,
+                transformer=transformer,
+                adapter_name=adapter_name,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+            )
+
+        if len(transformer_norm_state_dict) > 0:
+            transformer._transformer_norm_layers = self._load_norm_into_transformer(
+                transformer_norm_state_dict,
+                transformer=transformer,
+                discard_original_layers=False,
+            )
+
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
         if len(text_encoder_state_dict) > 0:
             self.load_lora_into_text_encoder(
@@ -1656,10 +1894,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
|
1656
1894
|
lora_scale=self.lora_scale,
|
1657
1895
|
adapter_name=adapter_name,
|
1658
1896
|
_pipeline=self,
|
1897
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
1659
1898
|
)
|
1660
1899
|
|
1661
1900
|
@classmethod
|
1662
|
-
def load_lora_into_transformer(
|
1901
|
+
def load_lora_into_transformer(
|
1902
|
+
cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
1903
|
+
):
|
1663
1904
|
"""
|
1664
1905
|
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
1665
1906
|
|
@@ -1672,78 +1913,86 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                The value of the network alpha used for stable learning and preventing underflow. This value has the
                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
-            transformer (`
+            transformer (`FluxTransformer2DModel`):
                The Transformer model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
        """
-
+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )

+        # Load the layers corresponding to transformer.
        keys = list(state_dict.keys())
+        transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
+        if transformer_present:
+            logger.info(f"Loading {cls.transformer_name}.")
+            transformer.load_lora_adapter(
+                state_dict,
+                network_alphas=network_alphas,
+                adapter_name=adapter_name,
+                _pipeline=_pipeline,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+            )

-
-
-
-
+    @classmethod
+    def _load_norm_into_transformer(
+        cls,
+        state_dict,
+        transformer,
+        prefix=None,
+        discard_original_layers=False,
+    ) -> Dict[str, torch.Tensor]:
+        # Remove prefix if present
+        prefix = prefix or cls.transformer_name
+        for key in list(state_dict.keys()):
+            if key.split(".")[0] == prefix:
+                state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+
+        # Find invalid keys
+        transformer_state_dict = transformer.state_dict()
+        transformer_keys = set(transformer_state_dict.keys())
+        state_dict_keys = set(state_dict.keys())
+        extra_keys = list(state_dict_keys - transformer_keys)
+
+        if extra_keys:
+            logger.warning(
+                f"Unsupported keys found in state dict when trying to load normalization layers into the transformer. The following keys will be ignored:\n{extra_keys}."
+            )

-
-
-        first_key = next(iter(state_dict.keys()))
-        if "lora_A" not in first_key:
-            state_dict = convert_unet_state_dict_to_peft(state_dict)
+        for key in extra_keys:
+            state_dict.pop(key)

-
-
-
-
+        # Save the layers that are going to be overwritten so that unload_lora_weights can work as expected
+        overwritten_layers_state_dict = {}
+        if not discard_original_layers:
+            for key in state_dict.keys():
+                overwritten_layers_state_dict[key] = transformer_state_dict[key].clone()

-
-
-
-
+        logger.info(
+            "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer "
+            'as opposed to the LoRA layers that will co-exist separately until the "fuse_lora()" method is called. That is to say, the normalization layers will always be directly '
+            "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed. This might also have implications when dealing with multiple LoRAs. "
+            "If you notice something unexpected, please open an issue: https://github.com/huggingface/diffusers/issues."
+        )

-
-
-
-        network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
+        # We can't load with strict=True because the current state_dict does not contain all the transformer keys
+        incompatible_keys = transformer.load_state_dict(state_dict, strict=False)
+        unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)

-
-
-
-
-
-
-            else:
-                lora_config_kwargs.pop("use_dora")
-        lora_config = LoraConfig(**lora_config_kwargs)
-
-        # adapter_name
-        if adapter_name is None:
-            adapter_name = get_adapter_name(transformer)
-
-        # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-        # otherwise loading LoRA weights will lead to an error
-        is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
-        inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
-        incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
-
-        if incompatible_keys is not None:
-            # check only for unexpected keys
-            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-            if unexpected_keys:
-                logger.warning(
-                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                    f" {unexpected_keys}. "
-                )
+        # We shouldn't expect to see the supported norm keys here being present in the unexpected keys.
+        if unexpected_keys:
+            if any(norm_key in k for k in unexpected_keys for norm_key in cls._control_lora_supported_norm_keys):
+                raise ValueError(
+                    f"Found {unexpected_keys} as unexpected keys while trying to load norm layers into the transformer."
+                )

-
-        if is_model_cpu_offload:
-            _pipeline.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            _pipeline.enable_sequential_cpu_offload()
-        # Unsafe code />
+        return overwritten_layers_state_dict

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
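Worth noting for the hunk above: `_load_norm_into_transformer` writes the Flux Control norm weights straight into the transformer's state dict and returns the originals, which `load_lora_weights` stashes on `transformer._transformer_norm_layers` so they can be put back later. A hedged illustration of that bookkeeping, not diffusers API; `transformer_state_dict` and `norm_state_dict` are assumed local names:

```py
# Keep a copy of whatever is about to be overwritten, then fuse the norm weights in place.
overwritten = {k: transformer_state_dict[k].clone() for k in norm_state_dict}
transformer.load_state_dict(norm_state_dict, strict=False)
# Later, unload_lora_weights() effectively replays:
# transformer.load_state_dict(overwritten, strict=False)
```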
@@ -1756,6 +2005,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
        lora_scale=1.0,
        adapter_name=None,
        _pipeline=None,
+        low_cpu_mem_usage=False,
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1765,7 +2015,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                A standard state dict containing the lora layer parameters. The key should be prefixed with an
                additional `text_encoder` to distinguish between unet lora layers.
            network_alphas (`Dict[str, float]`):
-
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
            text_encoder (`CLIPTextModel`):
                The text encoder model to load the LoRA layers into.
            prefix (`str`):
@@ -1776,10 +2028,27 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

+        peft_kwargs = {}
+        if low_cpu_mem_usage:
+            if not is_peft_version(">=", "0.13.1"):
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+                )
+            if not is_transformers_version(">", "4.45.2"):
+                # Note from sayakpaul: It's not in `transformers` stable yet.
+                # https://github.com/huggingface/transformers/pull/33725/
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+                )
+            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
        from peft import LoraConfig

        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -1827,6 +2096,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
        }

        lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
        if "use_dora" in lora_config_kwargs:
            if lora_config_kwargs["use_dora"]:
                if is_peft_version("<", "0.9.0"):
@@ -1836,6 +2106,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
            else:
                if is_peft_version("<", "0.9.0"):
                    lora_config_kwargs.pop("use_dora")
+
+        if "lora_bias" in lora_config_kwargs:
+            if lora_config_kwargs["lora_bias"]:
+                if is_peft_version("<=", "0.13.2"):
+                    raise ValueError(
+                        "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                    )
+            else:
+                if is_peft_version("<=", "0.13.2"):
+                    lora_config_kwargs.pop("lora_bias")
+
        lora_config = LoraConfig(**lora_config_kwargs)

        # adapter_name
@@ -1850,6 +2131,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
            adapter_name=adapter_name,
            adapter_state_dict=text_encoder_lora_state_dict,
            peft_config=lora_config,
+            **peft_kwargs,
        )

        # scale LoRA layers with `lora_scale`
@@ -1919,7 +2201,6 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
            safe_serialization=safe_serialization,
        )

-    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
    def fuse_lora(
        self,
        components: List[str] = ["transformer", "text_encoder"],
@@ -1959,6 +2240,19 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
+
+        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+        if (
+            hasattr(transformer, "_transformer_norm_layers")
+            and isinstance(transformer._transformer_norm_layers, dict)
+            and len(transformer._transformer_norm_layers.keys()) > 0
+        ):
+            logger.info(
+                "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will be directly updated the state_dict of the transformer "
+                "as opposed to the LoRA layers that will co-exist separately until the 'fuse_lora()' method is called. That is to say, the normalization layers will always be directly "
+                "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed."
+            )
+
        super().fuse_lora(
            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
        )
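A hedged sketch of how fusing interacts with the norm handling introduced in the hunks above. Any norm layers shipped with a Flux Control LoRA are already baked into the transformer at load time; `fuse_lora()` only folds the LoRA matrices, and the saved originals are only restored on unload. The pipeline object is the placeholder from the earlier examples:

```py
pipe.fuse_lora(components=["transformer"], lora_scale=1.0)
pipe.unfuse_lora()          # unfolds the LoRA matrices from the linear layers
pipe.unload_lora_weights()  # also restores the saved normalization layers (see below)
```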
@@ -1977,8 +2271,168 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
        Args:
            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
        """
+        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+        if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
+            transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
+
        super().unfuse_lora(components=components)

+    # We override this here account for `_transformer_norm_layers`.
+    def unload_lora_weights(self):
+        super().unload_lora_weights()
+
+        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
+        if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers:
+            transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
+            transformer._transformer_norm_layers = None
+
+    @classmethod
+    def _maybe_expand_transformer_param_shape_or_error_(
+        cls,
+        transformer: torch.nn.Module,
+        lora_state_dict=None,
+        norm_state_dict=None,
+        prefix=None,
+    ) -> bool:
+        """
+        Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and
+        generalizes things a bit so that any parameter that needs expansion receives appropriate treatement.
+        """
+        state_dict = {}
+        if lora_state_dict is not None:
+            state_dict.update(lora_state_dict)
+        if norm_state_dict is not None:
+            state_dict.update(norm_state_dict)
+
+        # Remove prefix if present
+        prefix = prefix or cls.transformer_name
+        for key in list(state_dict.keys()):
+            if key.split(".")[0] == prefix:
+                state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+
+        # Expand transformer parameter shapes if they don't match lora
+        has_param_with_shape_update = False
+        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+        for name, module in transformer.named_modules():
+            if isinstance(module, torch.nn.Linear):
+                module_weight = module.weight.data
+                module_bias = module.bias.data if module.bias is not None else None
+                bias = module_bias is not None
+
+                lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name
+                lora_A_weight_name = f"{lora_base_name}.lora_A.weight"
+                lora_B_weight_name = f"{lora_base_name}.lora_B.weight"
+                if lora_A_weight_name not in state_dict:
+                    continue
+
+                in_features = state_dict[lora_A_weight_name].shape[1]
+                out_features = state_dict[lora_B_weight_name].shape[0]
+
+                # This means there's no need for an expansion in the params, so we simply skip.
+                if tuple(module_weight.shape) == (out_features, in_features):
+                    continue
+
+                module_out_features, module_in_features = module_weight.shape
+                debug_message = ""
+                if in_features > module_in_features:
+                    debug_message += (
+                        f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA '
+                        f"checkpoint contains higher number of features than expected. The number of input_features will be "
+                        f"expanded from {module_in_features} to {in_features}"
+                    )
+                if out_features > module_out_features:
+                    debug_message += (
+                        ", and the number of output features will be "
+                        f"expanded from {module_out_features} to {out_features}."
+                    )
+                else:
+                    debug_message += "."
+                if debug_message:
+                    logger.debug(debug_message)
+
+                if out_features > module_out_features or in_features > module_in_features:
+                    has_param_with_shape_update = True
+                    parent_module_name, _, current_module_name = name.rpartition(".")
+                    parent_module = transformer.get_submodule(parent_module_name)
+
+                    with torch.device("meta"):
+                        expanded_module = torch.nn.Linear(
+                            in_features, out_features, bias=bias, dtype=module_weight.dtype
+                        )
+                    # Only weights are expanded and biases are not. This is because only the input dimensions
+                    # are changed while the output dimensions remain the same. The shape of the weight tensor
+                    # is (out_features, in_features), while the shape of bias tensor is (out_features,), which
+                    # explains the reason why only weights are expanded.
+                    new_weight = torch.zeros_like(
+                        expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
+                    )
+                    slices = tuple(slice(0, dim) for dim in module_weight.shape)
+                    new_weight[slices] = module_weight
+                    tmp_state_dict = {"weight": new_weight}
+                    if module_bias is not None:
+                        tmp_state_dict["bias"] = module_bias
+                    expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True)
+
+                    setattr(parent_module, current_module_name, expanded_module)
+
+                    del tmp_state_dict
+
+                    if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
+                        attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
+                        new_value = int(expanded_module.weight.data.shape[1])
+                        old_value = getattr(transformer.config, attribute_name)
+                        setattr(transformer.config, attribute_name, new_value)
+                        logger.info(
+                            f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}."
+                        )
+
+        return has_param_with_shape_update
+
+    @classmethod
+    def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
+        expanded_module_names = set()
+        transformer_state_dict = transformer.state_dict()
+        prefix = f"{cls.transformer_name}."
+
+        lora_module_names = [
+            key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight")
+        ]
+        lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)]
+        lora_module_names = sorted(set(lora_module_names))
+        transformer_module_names = sorted({name for name, _ in transformer.named_modules()})
+        unexpected_modules = set(lora_module_names) - set(transformer_module_names)
+        if unexpected_modules:
+            logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
+
+        is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+        for k in lora_module_names:
+            if k in unexpected_modules:
+                continue
+
+            base_param_name = (
+                f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
+            )
+            base_weight_param = transformer_state_dict[base_param_name]
+            lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
+
+            if base_weight_param.shape[1] > lora_A_param.shape[1]:
+                shape = (lora_A_param.shape[0], base_weight_param.shape[1])
+                expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
+                expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
+                lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
+                expanded_module_names.add(k)
+            elif base_weight_param.shape[1] < lora_A_param.shape[1]:
+                raise NotImplementedError(
+                    f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
+                )
+
+        if expanded_module_names:
+            logger.info(
+                f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new."
+            )
+
+        return lora_state_dict
+
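The core idea of `_maybe_expand_lora_state_dict` above is zero padding: a LoRA `lora_A` matrix trained against a narrower input (for example 64 latent channels) is padded so it applies to an expanded linear layer (for example 128 input features), with the extra features contributing nothing. A self-contained, hedged illustration with example shapes only:

```py
import torch

rank, lora_in, base_in = 16, 64, 128
lora_A = torch.randn(rank, lora_in)

# Pad the input dimension with zeros so the adapter matches the wider base layer.
expanded = torch.zeros(rank, base_in)
expanded[:, :lora_in].copy_(lora_A)
assert torch.equal(expanded[:, :lora_in], lora_A)
```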

# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
@@ -1988,7 +2442,10 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
    text_encoder_name = TEXT_ENCODER_NAME

    @classmethod
-
+    # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
+    def load_lora_into_transformer(
+        cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    ):
        """
        This will load the LoRA layers specified in `state_dict` into `transformer`.

@@ -1998,78 +2455,35 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                encoder lora layers.
            network_alphas (`Dict[str, float]`):
-
-
-
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+            transformer (`UVit2DModel`):
+                The Transformer model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
        """
-        if not
-            raise ValueError(
-
-
+        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )

+        # Load the layers corresponding to transformer.
        keys = list(state_dict.keys())
-
-
-
-
-
-
-
-
-
-
-        }
-
-        if len(state_dict.keys()) > 0:
-            if adapter_name in getattr(transformer, "peft_config", {}):
-                raise ValueError(
-                    f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
-                )
-
-            rank = {}
-            for key, val in state_dict.items():
-                if "lora_B" in key:
-                    rank[key] = val.shape[1]
-
-            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
-                    raise ValueError(
-                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                    )
-                else:
-                    lora_config_kwargs.pop("use_dora")
-            lora_config = LoraConfig(**lora_config_kwargs)
-
-            # adapter_name
-            if adapter_name is None:
-                adapter_name = get_adapter_name(transformer)
-
-            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-            # otherwise loading LoRA weights will lead to an error
-            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
-            inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
-            incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
-
-            if incompatible_keys is not None:
-                # check only for unexpected keys
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    logger.warning(
-                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                        f" {unexpected_keys}. "
-                    )
-
-            # Offload back.
-            if is_model_cpu_offload:
-                _pipeline.enable_model_cpu_offload()
-            elif is_sequential_cpu_offload:
-                _pipeline.enable_sequential_cpu_offload()
-            # Unsafe code />
+        transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
+        if transformer_present:
+            logger.info(f"Loading {cls.transformer_name}.")
+            transformer.load_lora_adapter(
+                state_dict,
+                network_alphas=network_alphas,
+                adapter_name=adapter_name,
+                _pipeline=_pipeline,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+            )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -2082,6 +2496,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
        lora_scale=1.0,
        adapter_name=None,
        _pipeline=None,
+        low_cpu_mem_usage=False,
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -2091,7 +2506,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
                A standard state dict containing the lora layer parameters. The key should be prefixed with an
                additional `text_encoder` to distinguish between unet lora layers.
            network_alphas (`Dict[str, float]`):
-
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
            text_encoder (`CLIPTextModel`):
                The text encoder model to load the LoRA layers into.
            prefix (`str`):
@@ -2102,10 +2519,27 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

+        peft_kwargs = {}
+        if low_cpu_mem_usage:
+            if not is_peft_version(">=", "0.13.1"):
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+                )
+            if not is_transformers_version(">", "4.45.2"):
+                # Note from sayakpaul: It's not in `transformers` stable yet.
+                # https://github.com/huggingface/transformers/pull/33725/
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+                )
+            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
        from peft import LoraConfig

        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -2153,6 +2587,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
        }

        lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
        if "use_dora" in lora_config_kwargs:
            if lora_config_kwargs["use_dora"]:
                if is_peft_version("<", "0.9.0"):
@@ -2162,6 +2597,17 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
            else:
                if is_peft_version("<", "0.9.0"):
                    lora_config_kwargs.pop("use_dora")
+
+        if "lora_bias" in lora_config_kwargs:
+            if lora_config_kwargs["lora_bias"]:
+                if is_peft_version("<=", "0.13.2"):
+                    raise ValueError(
+                        "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                    )
+            else:
+                if is_peft_version("<=", "0.13.2"):
+                    lora_config_kwargs.pop("lora_bias")
+
        lora_config = LoraConfig(**lora_config_kwargs)

        # adapter_name
@@ -2176,6 +2622,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
            adapter_name=adapter_name,
            adapter_state_dict=text_encoder_lora_state_dict,
            peft_config=lora_config,
+            **peft_kwargs,
        )

        # scale LoRA layers with `lora_scale`
@@ -2245,6 +2692,1545 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
        )


+
+class CogVideoXLoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoXPipeline`].
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+
+    @classmethod
+    @validate_hf_hub_args
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return state dict for lora weights and the network alphas.
+
+        <Tip warning={true}>
+
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+
+        """
+        # Load the main state dict first which has the LoRA layers for either of
+        # transformer and text encoder or both.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "attn_procs_weights",
+            "framework": "pytorch",
+        }
+
+        state_dict = _fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+        return state_dict
+
+    def load_lora_weights(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+        `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+        [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+        dict is loaded into `self.transformer`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # if a dict is passed, copy it instead of modifying it inplace
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
+    def load_lora_into_transformer(
+        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    ):
+        """
+        This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+                encoder lora layers.
+            transformer (`CogVideoXTransformer3DModel`):
+                The Transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+        """
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )
+
+    @classmethod
+    # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the UNet and text encoder.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training and you
+                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+                process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+        """
+        state_dict = {}
+
+        if not transformer_lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
+
+        if transformer_lora_layers:
+            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        # Save the model
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer", "text_encoder"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
+    def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+            unfuse_text_encoder (`bool`, defaults to `True`):
+                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+                LoRA parameters then it won't have any effect.
+        """
+        super().unfuse_lora(components=components)
+
+
3002
|
+
class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
3003
|
+
r"""
|
3004
|
+
Load LoRA layers into [`MochiTransformer3DModel`]. Specific to [`MochiPipeline`].
|
3005
|
+
"""
|
3006
|
+
|
3007
|
+
_lora_loadable_modules = ["transformer"]
|
3008
|
+
transformer_name = TRANSFORMER_NAME
|
3009
|
+
|
3010
|
+
@classmethod
|
3011
|
+
@validate_hf_hub_args
|
3012
|
+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
|
3013
|
+
def lora_state_dict(
|
3014
|
+
cls,
|
3015
|
+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
3016
|
+
**kwargs,
|
3017
|
+
):
|
3018
|
+
r"""
|
3019
|
+
Return state dict for lora weights and the network alphas.
|
3020
|
+
|
3021
|
+
<Tip warning={true}>
|
3022
|
+
|
3023
|
+
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
3024
|
+
|
3025
|
+
This function is experimental and might change in the future.
|
3026
|
+
|
3027
|
+
</Tip>
|
3028
|
+
|
3029
|
+
Parameters:
|
3030
|
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
3031
|
+
Can be either:
|
3032
|
+
|
3033
|
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
3034
|
+
the Hub.
|
3035
|
+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
3036
|
+
with [`ModelMixin.save_pretrained`].
|
3037
|
+
- A [torch state
|
3038
|
+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
3039
|
+
|
3040
|
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
3041
|
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
3042
|
+
is not used.
|
3043
|
+
force_download (`bool`, *optional*, defaults to `False`):
|
3044
|
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
3045
|
+
cached versions if they exist.
|
3046
|
+
|
3047
|
+
proxies (`Dict[str, str]`, *optional*):
|
3048
|
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
3049
|
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
3050
|
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
3051
|
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
3052
|
+
won't be downloaded from the Hub.
|
3053
|
+
token (`str` or *bool*, *optional*):
|
3054
|
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
3055
|
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
3056
|
+
revision (`str`, *optional*, defaults to `"main"`):
|
3057
|
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
3058
|
+
allowed by Git.
|
3059
|
+
subfolder (`str`, *optional*, defaults to `""`):
|
3060
|
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
3061
|
+
|
3062
|
+
"""
|
3063
|
+
# Load the main state dict first which has the LoRA layers for either of
|
3064
|
+
# transformer and text encoder or both.
|
3065
|
+
cache_dir = kwargs.pop("cache_dir", None)
|
3066
|
+
force_download = kwargs.pop("force_download", False)
|
3067
|
+
proxies = kwargs.pop("proxies", None)
|
3068
|
+
local_files_only = kwargs.pop("local_files_only", None)
|
3069
|
+
token = kwargs.pop("token", None)
|
3070
|
+
revision = kwargs.pop("revision", None)
|
3071
|
+
subfolder = kwargs.pop("subfolder", None)
|
3072
|
+
weight_name = kwargs.pop("weight_name", None)
|
3073
|
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
3074
|
+
|
3075
|
+
allow_pickle = False
|
3076
|
+
if use_safetensors is None:
|
3077
|
+
use_safetensors = True
|
3078
|
+
allow_pickle = True
|
3079
|
+
|
3080
|
+
user_agent = {
|
3081
|
+
"file_type": "attn_procs_weights",
|
3082
|
+
"framework": "pytorch",
|
3083
|
+
}
|
3084
|
+
|
3085
|
+
state_dict = _fetch_state_dict(
|
3086
|
+
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
3087
|
+
weight_name=weight_name,
|
3088
|
+
use_safetensors=use_safetensors,
|
3089
|
+
local_files_only=local_files_only,
|
3090
|
+
cache_dir=cache_dir,
|
3091
|
+
force_download=force_download,
|
3092
|
+
proxies=proxies,
|
3093
|
+
token=token,
|
3094
|
+
revision=revision,
|
3095
|
+
subfolder=subfolder,
|
3096
|
+
user_agent=user_agent,
|
3097
|
+
allow_pickle=allow_pickle,
|
3098
|
+
)
|
3099
|
+
|
3100
|
+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
3101
|
+
if is_dora_scale_present:
|
3102
|
+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
3103
|
+
logger.warning(warn_msg)
|
3104
|
+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
3105
|
+
|
3106
|
+
return state_dict
|
3107
|
+
|
3108
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
3109
|
+
def load_lora_weights(
|
3110
|
+
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
3111
|
+
):
|
3112
|
+
"""
|
3113
|
+
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
3114
|
+
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
|
3115
|
+
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
|
3116
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
|
3117
|
+
dict is loaded into `self.transformer`.
|
3118
|
+
|
3119
|
+
Parameters:
|
3120
|
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
3121
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
3122
|
+
adapter_name (`str`, *optional*):
|
3123
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
3124
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
3125
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
3126
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
3127
|
+
weights.
|
3128
|
+
kwargs (`dict`, *optional*):
|
3129
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
3130
|
+
"""
|
3131
|
+
if not USE_PEFT_BACKEND:
|
3132
|
+
raise ValueError("PEFT backend is required for this method.")
|
3133
|
+
|
3134
|
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
3135
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
3136
|
+
raise ValueError(
|
3137
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
3138
|
+
)
|
3139
|
+
|
3140
|
+
# if a dict is passed, copy it instead of modifying it inplace
|
3141
|
+
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
3142
|
+
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
3143
|
+
|
3144
|
+
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
3145
|
+
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
3146
|
+
|
3147
|
+
is_correct_format = all("lora" in key for key in state_dict.keys())
|
3148
|
+
if not is_correct_format:
|
3149
|
+
raise ValueError("Invalid LoRA checkpoint.")
|
3150
|
+
|
3151
|
+
self.load_lora_into_transformer(
|
3152
|
+
state_dict,
|
3153
|
+
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
3154
|
+
adapter_name=adapter_name,
|
3155
|
+
_pipeline=self,
|
3156
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
3157
|
+
)
|
3158
|
+
|
3159
|
+
@classmethod
|
3160
|
+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel
|
3161
|
+
def load_lora_into_transformer(
|
3162
|
+
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
3163
|
+
):
|
3164
|
+
"""
|
3165
|
+
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
3166
|
+
|
3167
|
+
Parameters:
|
3168
|
+
state_dict (`dict`):
|
3169
|
+
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
3170
|
+
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
3171
|
+
encoder lora layers.
|
3172
|
+
transformer (`MochiTransformer3DModel`):
|
3173
|
+
The Transformer model to load the LoRA layers into.
|
3174
|
+
adapter_name (`str`, *optional*):
|
3175
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
3176
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
3177
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
3178
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
3179
|
+
weights.
|
3180
|
+
"""
|
3181
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
3182
|
+
raise ValueError(
|
3183
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
3184
|
+
)
|
3185
|
+
|
3186
|
+
# Load the layers corresponding to transformer.
|
3187
|
+
logger.info(f"Loading {cls.transformer_name}.")
|
3188
|
+
transformer.load_lora_adapter(
|
3189
|
+
state_dict,
|
3190
|
+
network_alphas=None,
|
3191
|
+
adapter_name=adapter_name,
|
3192
|
+
_pipeline=_pipeline,
|
3193
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
3194
|
+
)
|
3195
|
+
|
3196
|
+
@classmethod
|
3197
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
3198
|
+
def save_lora_weights(
|
3199
|
+
cls,
|
3200
|
+
save_directory: Union[str, os.PathLike],
|
3201
|
+
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
3202
|
+
is_main_process: bool = True,
|
3203
|
+
weight_name: str = None,
|
3204
|
+
save_function: Callable = None,
|
3205
|
+
safe_serialization: bool = True,
|
3206
|
+
):
|
3207
|
+
r"""
|
3208
|
+
Save the LoRA parameters corresponding to the UNet and text encoder.
|
3209
|
+
|
3210
|
+
Arguments:
|
3211
|
+
save_directory (`str` or `os.PathLike`):
|
3212
|
+
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
3213
|
+
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
3214
|
+
State dict of the LoRA layers corresponding to the `transformer`.
|
3215
|
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
3216
|
+
Whether the process calling this is the main process or not. Useful during distributed training and you
|
3217
|
+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
3218
|
+
process to avoid race conditions.
|
3219
|
+
save_function (`Callable`):
|
3220
|
+
The function to use to save the state dictionary. Useful during distributed training when you need to
|
3221
|
+
replace `torch.save` with another method. Can be configured with the environment variable
|
3222
|
+
`DIFFUSERS_SAVE_MODE`.
|
3223
|
+
safe_serialization (`bool`, *optional*, defaults to `True`):
|
3224
|
+
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
3225
|
+
"""
|
3226
|
+
state_dict = {}
|
3227
|
+
|
3228
|
+
if not transformer_lora_layers:
|
3229
|
+
raise ValueError("You must pass `transformer_lora_layers`.")
|
3230
|
+
|
3231
|
+
if transformer_lora_layers:
|
3232
|
+
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
3233
|
+
|
3234
|
+
# Save the model
|
3235
|
+
cls.write_lora_layers(
|
3236
|
+
state_dict=state_dict,
|
3237
|
+
save_directory=save_directory,
|
3238
|
+
is_main_process=is_main_process,
|
3239
|
+
weight_name=weight_name,
|
3240
|
+
save_function=save_function,
|
3241
|
+
safe_serialization=safe_serialization,
|
3242
|
+
)
|
3243
|
+
|
3244
|
+
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
|
3245
|
+
def fuse_lora(
|
3246
|
+
self,
|
3247
|
+
components: List[str] = ["transformer", "text_encoder"],
|
3248
|
+
lora_scale: float = 1.0,
|
3249
|
+
safe_fusing: bool = False,
|
3250
|
+
adapter_names: Optional[List[str]] = None,
|
3251
|
+
**kwargs,
|
3252
|
+
):
|
3253
|
+
r"""
|
3254
|
+
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
3255
|
+
|
3256
|
+
<Tip warning={true}>
|
3257
|
+
|
3258
|
+
This is an experimental API.
|
3259
|
+
|
3260
|
+
</Tip>
|
3261
|
+
|
3262
|
+
Args:
|
3263
|
+
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
3264
|
+
lora_scale (`float`, defaults to 1.0):
|
3265
|
+
Controls how much to influence the outputs with the LoRA parameters.
|
3266
|
+
safe_fusing (`bool`, defaults to `False`):
|
3267
|
+
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
3268
|
+
adapter_names (`List[str]`, *optional*):
|
3269
|
+
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
3270
|
+
|
3271
|
+
Example:
|
3272
|
+
|
3273
|
+
```py
|
3274
|
+
from diffusers import DiffusionPipeline
|
3275
|
+
import torch
|
3276
|
+
|
3277
|
+
pipeline = DiffusionPipeline.from_pretrained(
|
3278
|
+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
3279
|
+
).to("cuda")
|
3280
|
+
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
3281
|
+
pipeline.fuse_lora(lora_scale=0.7)
|
3282
|
+
```
|
3283
|
+
"""
|
3284
|
+
super().fuse_lora(
|
3285
|
+
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
3286
|
+
)
|
3287
|
+
|
3288
|
+
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
|
3289
|
+
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
3290
|
+
r"""
|
3291
|
+
Reverses the effect of
|
3292
|
+
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
3293
|
+
|
3294
|
+
<Tip warning={true}>
|
3295
|
+
|
3296
|
+
This is an experimental API.
|
3297
|
+
|
3298
|
+
</Tip>
|
3299
|
+
|
3300
|
+
Args:
|
3301
|
+
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
3302
|
+
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
3303
|
+
unfuse_text_encoder (`bool`, defaults to `True`):
|
3304
|
+
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
3305
|
+
LoRA parameters then it won't have any effect.
|
3306
|
+
"""
|
3307
|
+
super().unfuse_lora(components=components)
|
3308
|
+
|
3309
|
+
|
3310
|
+
class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
3311
|
+
r"""
|
3312
|
+
Load LoRA layers into [`LTXVideoTransformer3DModel`]. Specific to [`LTXPipeline`].
|
3313
|
+
"""
|
3314
|
+
|
3315
|
+
_lora_loadable_modules = ["transformer"]
|
3316
|
+
transformer_name = TRANSFORMER_NAME
|
3317
|
+
|
3318
|
+
@classmethod
|
3319
|
+
@validate_hf_hub_args
|
3320
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
|
3321
|
+
def lora_state_dict(
|
3322
|
+
cls,
|
3323
|
+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
3324
|
+
**kwargs,
|
3325
|
+
):
|
3326
|
+
r"""
|
3327
|
+
Return state dict for lora weights and the network alphas.
|
3328
|
+
|
3329
|
+
<Tip warning={true}>
|
3330
|
+
|
3331
|
+
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
3332
|
+
|
3333
|
+
This function is experimental and might change in the future.
|
3334
|
+
|
3335
|
+
</Tip>
|
3336
|
+
|
3337
|
+
Parameters:
|
3338
|
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
3339
|
+
Can be either:
|
3340
|
+
|
3341
|
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
3342
|
+
the Hub.
|
3343
|
+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
3344
|
+
with [`ModelMixin.save_pretrained`].
|
3345
|
+
- A [torch state
|
3346
|
+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
3347
|
+
|
3348
|
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
3349
|
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
3350
|
+
is not used.
|
3351
|
+
force_download (`bool`, *optional*, defaults to `False`):
|
3352
|
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
3353
|
+
cached versions if they exist.
|
3354
|
+
|
3355
|
+
proxies (`Dict[str, str]`, *optional*):
|
3356
|
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
3357
|
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
3358
|
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
3359
|
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
3360
|
+
won't be downloaded from the Hub.
|
3361
|
+
token (`str` or *bool*, *optional*):
|
3362
|
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
3363
|
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
3364
|
+
revision (`str`, *optional*, defaults to `"main"`):
|
3365
|
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
3366
|
+
allowed by Git.
|
3367
|
+
subfolder (`str`, *optional*, defaults to `""`):
|
3368
|
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
3369
|
+
|
3370
|
+
"""
|
3371
|
+
# Load the main state dict first which has the LoRA layers for either of
|
3372
|
+
# transformer and text encoder or both.
|
3373
|
+
cache_dir = kwargs.pop("cache_dir", None)
|
3374
|
+
force_download = kwargs.pop("force_download", False)
|
3375
|
+
proxies = kwargs.pop("proxies", None)
|
3376
|
+
local_files_only = kwargs.pop("local_files_only", None)
|
3377
|
+
token = kwargs.pop("token", None)
|
3378
|
+
revision = kwargs.pop("revision", None)
|
3379
|
+
subfolder = kwargs.pop("subfolder", None)
|
3380
|
+
weight_name = kwargs.pop("weight_name", None)
|
3381
|
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
3382
|
+
|
3383
|
+
allow_pickle = False
|
3384
|
+
if use_safetensors is None:
|
3385
|
+
use_safetensors = True
|
3386
|
+
allow_pickle = True
|
3387
|
+
|
3388
|
+
user_agent = {
|
3389
|
+
"file_type": "attn_procs_weights",
|
3390
|
+
"framework": "pytorch",
|
3391
|
+
}
|
3392
|
+
|
3393
|
+
state_dict = _fetch_state_dict(
|
3394
|
+
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
3395
|
+
weight_name=weight_name,
|
3396
|
+
use_safetensors=use_safetensors,
|
3397
|
+
local_files_only=local_files_only,
|
3398
|
+
cache_dir=cache_dir,
|
3399
|
+
force_download=force_download,
|
3400
|
+
proxies=proxies,
|
3401
|
+
token=token,
|
3402
|
+
revision=revision,
|
3403
|
+
subfolder=subfolder,
|
3404
|
+
user_agent=user_agent,
|
3405
|
+
allow_pickle=allow_pickle,
|
3406
|
+
)
|
3407
|
+
|
3408
|
+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
3409
|
+
if is_dora_scale_present:
|
3410
|
+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
3411
|
+
logger.warning(warn_msg)
|
3412
|
+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
3413
|
+
|
3414
|
+
return state_dict
|
3415
|
+
|
3416
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
3417
|
+
def load_lora_weights(
|
3418
|
+
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
3419
|
+
):
|
3420
|
+
"""
|
3421
|
+
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
3422
|
+
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
|
3423
|
+
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
|
3424
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
|
3425
|
+
dict is loaded into `self.transformer`.
|
3426
|
+
|
3427
|
+
Parameters:
|
3428
|
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
3429
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
3430
|
+
adapter_name (`str`, *optional*):
|
3431
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
3432
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
3433
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
3434
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
3435
|
+
weights.
|
3436
|
+
kwargs (`dict`, *optional*):
|
3437
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
3438
|
+
"""
|
3439
|
+
if not USE_PEFT_BACKEND:
|
3440
|
+
raise ValueError("PEFT backend is required for this method.")
|
3441
|
+
|
3442
|
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
3443
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
3444
|
+
raise ValueError(
|
3445
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
3446
|
+
)
|
3447
|
+
|
3448
|
+
# if a dict is passed, copy it instead of modifying it inplace
|
3449
|
+
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
3450
|
+
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
3451
|
+
|
3452
|
+
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
3453
|
+
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
3454
|
+
|
3455
|
+
is_correct_format = all("lora" in key for key in state_dict.keys())
|
3456
|
+
if not is_correct_format:
|
3457
|
+
raise ValueError("Invalid LoRA checkpoint.")
|
3458
|
+
|
3459
|
+
self.load_lora_into_transformer(
|
3460
|
+
state_dict,
|
3461
|
+
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
3462
|
+
adapter_name=adapter_name,
|
3463
|
+
_pipeline=self,
|
3464
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
3465
|
+
)
|
3466
|
+
|
3467
|
+
@classmethod
|
3468
|
+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel
|
3469
|
+
def load_lora_into_transformer(
|
3470
|
+
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
3471
|
+
):
|
3472
|
+
"""
|
3473
|
+
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
3474
|
+
|
3475
|
+
Parameters:
|
3476
|
+
state_dict (`dict`):
|
3477
|
+
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
3478
|
+
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
3479
|
+
encoder lora layers.
|
3480
|
+
transformer (`LTXVideoTransformer3DModel`):
|
3481
|
+
The Transformer model to load the LoRA layers into.
|
3482
|
+
adapter_name (`str`, *optional*):
|
3483
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
3484
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
3485
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
3486
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
3487
|
+
weights.
|
3488
|
+
"""
|
3489
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
3490
|
+
raise ValueError(
|
3491
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
3492
|
+
)
|
3493
|
+
|
3494
|
+
# Load the layers corresponding to transformer.
|
3495
|
+
logger.info(f"Loading {cls.transformer_name}.")
|
3496
|
+
transformer.load_lora_adapter(
|
3497
|
+
state_dict,
|
3498
|
+
network_alphas=None,
|
3499
|
+
adapter_name=adapter_name,
|
3500
|
+
_pipeline=_pipeline,
|
3501
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
3502
|
+
)
|
3503
|
+
|
3504
|
+
@classmethod
|
3505
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
3506
|
+
def save_lora_weights(
|
3507
|
+
cls,
|
3508
|
+
save_directory: Union[str, os.PathLike],
|
3509
|
+
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
3510
|
+
is_main_process: bool = True,
|
3511
|
+
weight_name: str = None,
|
3512
|
+
save_function: Callable = None,
|
3513
|
+
safe_serialization: bool = True,
|
3514
|
+
):
|
3515
|
+
r"""
|
3516
|
+
Save the LoRA parameters corresponding to the UNet and text encoder.
|
3517
|
+
|
3518
|
+
Arguments:
|
3519
|
+
save_directory (`str` or `os.PathLike`):
|
3520
|
+
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
3521
|
+
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
3522
|
+
State dict of the LoRA layers corresponding to the `transformer`.
|
3523
|
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
3524
|
+
Whether the process calling this is the main process or not. Useful during distributed training and you
|
3525
|
+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
3526
|
+
process to avoid race conditions.
|
3527
|
+
save_function (`Callable`):
|
3528
|
+
The function to use to save the state dictionary. Useful during distributed training when you need to
|
3529
|
+
replace `torch.save` with another method. Can be configured with the environment variable
|
3530
|
+
`DIFFUSERS_SAVE_MODE`.
|
3531
|
+
safe_serialization (`bool`, *optional*, defaults to `True`):
|
3532
|
+
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
3533
|
+
"""
|
3534
|
+
state_dict = {}
|
3535
|
+
|
3536
|
+
if not transformer_lora_layers:
|
3537
|
+
raise ValueError("You must pass `transformer_lora_layers`.")
|
3538
|
+
|
3539
|
+
if transformer_lora_layers:
|
3540
|
+
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
3541
|
+
|
3542
|
+
# Save the model
|
3543
|
+
cls.write_lora_layers(
|
3544
|
+
state_dict=state_dict,
|
3545
|
+
save_directory=save_directory,
|
3546
|
+
is_main_process=is_main_process,
|
3547
|
+
weight_name=weight_name,
|
3548
|
+
save_function=save_function,
|
3549
|
+
safe_serialization=safe_serialization,
|
3550
|
+
)
|
3551
|
+
|
3552
|
+
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
|
3553
|
+
def fuse_lora(
|
3554
|
+
self,
|
3555
|
+
components: List[str] = ["transformer", "text_encoder"],
|
3556
|
+
lora_scale: float = 1.0,
|
3557
|
+
safe_fusing: bool = False,
|
3558
|
+
adapter_names: Optional[List[str]] = None,
|
3559
|
+
**kwargs,
|
3560
|
+
):
|
3561
|
+
r"""
|
3562
|
+
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
3563
|
+
|
3564
|
+
<Tip warning={true}>
|
3565
|
+
|
3566
|
+
This is an experimental API.
|
3567
|
+
|
3568
|
+
</Tip>
|
3569
|
+
|
3570
|
+
Args:
|
3571
|
+
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
3572
|
+
lora_scale (`float`, defaults to 1.0):
|
3573
|
+
Controls how much to influence the outputs with the LoRA parameters.
|
3574
|
+
safe_fusing (`bool`, defaults to `False`):
|
3575
|
+
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
3576
|
+
adapter_names (`List[str]`, *optional*):
|
3577
|
+
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
3578
|
+
|
3579
|
+
Example:
|
3580
|
+
|
3581
|
+
```py
|
3582
|
+
from diffusers import DiffusionPipeline
|
3583
|
+
import torch
|
3584
|
+
|
3585
|
+
pipeline = DiffusionPipeline.from_pretrained(
|
3586
|
+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
3587
|
+
).to("cuda")
|
3588
|
+
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
3589
|
+
pipeline.fuse_lora(lora_scale=0.7)
|
3590
|
+
```
|
3591
|
+
"""
|
3592
|
+
super().fuse_lora(
|
3593
|
+
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
3594
|
+
)
|
3595
|
+
|
3596
|
+
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
|
3597
|
+
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
3598
|
+
r"""
|
3599
|
+
Reverses the effect of
|
3600
|
+
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
3601
|
+
|
3602
|
+
<Tip warning={true}>
|
3603
|
+
|
3604
|
+
This is an experimental API.
|
3605
|
+
|
3606
|
+
</Tip>
|
3607
|
+
|
3608
|
+
Args:
|
3609
|
+
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
3610
|
+
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
3611
|
+
unfuse_text_encoder (`bool`, defaults to `True`):
|
3612
|
+
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
3613
|
+
LoRA parameters then it won't have any effect.
|
3614
|
+
"""
|
3615
|
+
super().unfuse_lora(components=components)
|
3616
|
+
|
3617
|
+
|
3618
|
+
class SanaLoraLoaderMixin(LoraBaseMixin):
|
3619
|
+
r"""
|
3620
|
+
Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].
|
3621
|
+
"""
|
3622
|
+
|
3623
|
+
_lora_loadable_modules = ["transformer"]
|
3624
|
+
transformer_name = TRANSFORMER_NAME
|
3625
|
+
|
3626
|
+
@classmethod
|
3627
|
+
@validate_hf_hub_args
|
3628
|
+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
|
3629
|
+
def lora_state_dict(
|
3630
|
+
cls,
|
3631
|
+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
3632
|
+
**kwargs,
|
3633
|
+
):
|
3634
|
+
r"""
|
3635
|
+
Return state dict for lora weights and the network alphas.
|
3636
|
+
|
3637
|
+
<Tip warning={true}>
|
3638
|
+
|
3639
|
+
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
3640
|
+
|
3641
|
+
This function is experimental and might change in the future.
|
3642
|
+
|
3643
|
+
</Tip>
|
3644
|
+
|
3645
|
+
Parameters:
|
3646
|
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
3647
|
+
Can be either:
|
3648
|
+
|
3649
|
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
3650
|
+
the Hub.
|
3651
|
+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
3652
|
+
with [`ModelMixin.save_pretrained`].
|
3653
|
+
- A [torch state
|
3654
|
+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
3655
|
+
|
3656
|
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
3657
|
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
3658
|
+
is not used.
|
3659
|
+
force_download (`bool`, *optional*, defaults to `False`):
|
3660
|
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
3661
|
+
cached versions if they exist.
|
3662
|
+
|
3663
|
+
proxies (`Dict[str, str]`, *optional*):
|
3664
|
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
3665
|
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
3666
|
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
3667
|
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
3668
|
+
won't be downloaded from the Hub.
|
3669
|
+
token (`str` or *bool*, *optional*):
|
3670
|
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
3671
|
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
3672
|
+
revision (`str`, *optional*, defaults to `"main"`):
|
3673
|
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
3674
|
+
allowed by Git.
|
3675
|
+
subfolder (`str`, *optional*, defaults to `""`):
|
3676
|
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
3677
|
+
|
3678
|
+
"""
|
3679
|
+
# Load the main state dict first which has the LoRA layers for either of
|
3680
|
+
# transformer and text encoder or both.
|
3681
|
+
cache_dir = kwargs.pop("cache_dir", None)
|
3682
|
+
force_download = kwargs.pop("force_download", False)
|
3683
|
+
proxies = kwargs.pop("proxies", None)
|
3684
|
+
local_files_only = kwargs.pop("local_files_only", None)
|
3685
|
+
token = kwargs.pop("token", None)
|
3686
|
+
revision = kwargs.pop("revision", None)
|
3687
|
+
subfolder = kwargs.pop("subfolder", None)
|
3688
|
+
weight_name = kwargs.pop("weight_name", None)
|
3689
|
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
3690
|
+
|
3691
|
+
allow_pickle = False
|
3692
|
+
if use_safetensors is None:
|
3693
|
+
use_safetensors = True
|
3694
|
+
allow_pickle = True
|
3695
|
+
|
3696
|
+
user_agent = {
|
3697
|
+
"file_type": "attn_procs_weights",
|
3698
|
+
"framework": "pytorch",
|
3699
|
+
}
|
3700
|
+
|
3701
|
+
state_dict = _fetch_state_dict(
|
3702
|
+
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
3703
|
+
weight_name=weight_name,
|
3704
|
+
use_safetensors=use_safetensors,
|
3705
|
+
local_files_only=local_files_only,
|
3706
|
+
cache_dir=cache_dir,
|
3707
|
+
force_download=force_download,
|
3708
|
+
proxies=proxies,
|
3709
|
+
token=token,
|
3710
|
+
revision=revision,
|
3711
|
+
subfolder=subfolder,
|
3712
|
+
user_agent=user_agent,
|
3713
|
+
allow_pickle=allow_pickle,
|
3714
|
+
)
|
3715
|
+
|
3716
|
+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
3717
|
+
if is_dora_scale_present:
|
3718
|
+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
3719
|
+
logger.warning(warn_msg)
|
3720
|
+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
3721
|
+
|
3722
|
+
return state_dict
|
3723
|
+
|
3724
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
3725
|
+
def load_lora_weights(
|
3726
|
+
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
3727
|
+
):
|
3728
|
+
"""
|
3729
|
+
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
3730
|
+
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
|
3731
|
+
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
|
3732
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
|
3733
|
+
dict is loaded into `self.transformer`.
|
3734
|
+
|
3735
|
+
Parameters:
|
3736
|
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
3737
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
3738
|
+
adapter_name (`str`, *optional*):
|
3739
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
3740
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
3741
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
3742
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
3743
|
+
weights.
|
3744
|
+
kwargs (`dict`, *optional*):
|
3745
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
3746
|
+
"""
|
3747
|
+
if not USE_PEFT_BACKEND:
|
3748
|
+
raise ValueError("PEFT backend is required for this method.")
|
3749
|
+
|
3750
|
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
3751
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
3752
|
+
raise ValueError(
|
3753
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
3754
|
+
)
|
3755
|
+
|
3756
|
+
# if a dict is passed, copy it instead of modifying it inplace
|
3757
|
+
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
3758
|
+
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
3759
|
+
|
3760
|
+
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
3761
|
+
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
3762
|
+
|
3763
|
+
is_correct_format = all("lora" in key for key in state_dict.keys())
|
3764
|
+
if not is_correct_format:
|
3765
|
+
raise ValueError("Invalid LoRA checkpoint.")
|
3766
|
+
|
3767
|
+
self.load_lora_into_transformer(
|
3768
|
+
state_dict,
|
3769
|
+
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
3770
|
+
adapter_name=adapter_name,
|
3771
|
+
_pipeline=self,
|
3772
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
3773
|
+
)
|
3774
|
+
|
3775
|
+
@classmethod
|
3776
|
+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel
|
3777
|
+
def load_lora_into_transformer(
|
3778
|
+
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
3779
|
+
):
|
3780
|
+
"""
|
3781
|
+
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
3782
|
+
|
3783
|
+
Parameters:
|
3784
|
+
state_dict (`dict`):
|
3785
|
+
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
3786
|
+
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
3787
|
+
encoder lora layers.
|
3788
|
+
transformer (`SanaTransformer2DModel`):
|
3789
|
+
The Transformer model to load the LoRA layers into.
|
3790
|
+
adapter_name (`str`, *optional*):
|
3791
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
3792
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
3793
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
3794
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
3795
|
+
weights.
|
3796
|
+
"""
|
3797
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
3798
|
+
raise ValueError(
|
3799
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
3800
|
+
)
|
3801
|
+
|
3802
|
+
# Load the layers corresponding to transformer.
|
3803
|
+
logger.info(f"Loading {cls.transformer_name}.")
|
3804
|
+
transformer.load_lora_adapter(
|
3805
|
+
state_dict,
|
3806
|
+
network_alphas=None,
|
3807
|
+
adapter_name=adapter_name,
|
3808
|
+
_pipeline=_pipeline,
|
3809
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
3810
|
+
)
|
3811
|
+
|
3812
|
+
@classmethod
|
3813
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
3814
|
+
def save_lora_weights(
|
3815
|
+
cls,
|
3816
|
+
save_directory: Union[str, os.PathLike],
|
3817
|
+
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
3818
|
+
is_main_process: bool = True,
|
3819
|
+
weight_name: str = None,
|
3820
|
+
save_function: Callable = None,
|
3821
|
+
safe_serialization: bool = True,
|
3822
|
+
):
|
3823
|
+
r"""
|
3824
|
+
Save the LoRA parameters corresponding to the UNet and text encoder.
|
3825
|
+
|
3826
|
+
Arguments:
|
3827
|
+
save_directory (`str` or `os.PathLike`):
|
3828
|
+
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
3829
|
+
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
3830
|
+
State dict of the LoRA layers corresponding to the `transformer`.
|
3831
|
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
3832
|
+
Whether the process calling this is the main process or not. Useful during distributed training and you
|
3833
|
+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
3834
|
+
process to avoid race conditions.
|
3835
|
+
save_function (`Callable`):
|
3836
|
+
The function to use to save the state dictionary. Useful during distributed training when you need to
|
3837
|
+
replace `torch.save` with another method. Can be configured with the environment variable
|
3838
|
+
`DIFFUSERS_SAVE_MODE`.
|
3839
|
+
safe_serialization (`bool`, *optional*, defaults to `True`):
|
3840
|
+
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
3841
|
+
"""
|
3842
|
+
state_dict = {}
|
3843
|
+
|
3844
|
+
if not transformer_lora_layers:
|
3845
|
+
raise ValueError("You must pass `transformer_lora_layers`.")
|
3846
|
+
|
3847
|
+
if transformer_lora_layers:
|
3848
|
+
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
3849
|
+
|
3850
|
+
# Save the model
|
3851
|
+
cls.write_lora_layers(
|
3852
|
+
state_dict=state_dict,
|
3853
|
+
save_directory=save_directory,
|
3854
|
+
is_main_process=is_main_process,
|
3855
|
+
weight_name=weight_name,
|
3856
|
+
save_function=save_function,
|
3857
|
+
safe_serialization=safe_serialization,
|
3858
|
+
)
|
3859
|
+
|
3860
|
+
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
|
3861
|
+
def fuse_lora(
|
3862
|
+
self,
|
3863
|
+
components: List[str] = ["transformer", "text_encoder"],
|
3864
|
+
lora_scale: float = 1.0,
|
3865
|
+
safe_fusing: bool = False,
|
3866
|
+
adapter_names: Optional[List[str]] = None,
|
3867
|
+
**kwargs,
|
3868
|
+
):
|
3869
|
+
r"""
|
3870
|
+
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
3871
|
+
|
3872
|
+
<Tip warning={true}>
|
3873
|
+
|
3874
|
+
This is an experimental API.
|
3875
|
+
|
3876
|
+
</Tip>
|
3877
|
+
|
3878
|
+
Args:
|
3879
|
+
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
3880
|
+
lora_scale (`float`, defaults to 1.0):
|
3881
|
+
Controls how much to influence the outputs with the LoRA parameters.
|
3882
|
+
safe_fusing (`bool`, defaults to `False`):
|
3883
|
+
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
3884
|
+
adapter_names (`List[str]`, *optional*):
|
3885
|
+
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
3886
|
+
|
3887
|
+
Example:
|
3888
|
+
|
3889
|
+
```py
|
3890
|
+
from diffusers import DiffusionPipeline
|
3891
|
+
import torch
|
3892
|
+
|
3893
|
+
pipeline = DiffusionPipeline.from_pretrained(
|
3894
|
+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
3895
|
+
).to("cuda")
|
3896
|
+
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
3897
|
+
pipeline.fuse_lora(lora_scale=0.7)
|
3898
|
+
```
|
3899
|
+
"""
|
3900
|
+
super().fuse_lora(
|
3901
|
+
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
3902
|
+
)
|
3903
|
+
|
3904
|
+
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
|
3905
|
+
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
3906
|
+
r"""
|
3907
|
+
Reverses the effect of
|
3908
|
+
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
3909
|
+
|
3910
|
+
<Tip warning={true}>
|
3911
|
+
|
3912
|
+
This is an experimental API.
|
3913
|
+
|
3914
|
+
</Tip>
|
3915
|
+
|
3916
|
+
Args:
|
3917
|
+
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
3918
|
+
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
3919
|
+
unfuse_text_encoder (`bool`, defaults to `True`):
|
3920
|
+
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
3921
|
+
LoRA parameters then it won't have any effect.
|
3922
|
+
"""
|
3923
|
+
super().unfuse_lora(components=components)
|
3924
|
+
|
3925
|
+
|
3926
|
+
class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
3927
|
+
r"""
|
3928
|
+
Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].
|
3929
|
+
"""
|
3930
|
+
|
3931
|
+
_lora_loadable_modules = ["transformer"]
|
3932
|
+
transformer_name = TRANSFORMER_NAME
|
3933
|
+
|
3934
|
+
@classmethod
|
3935
|
+
@validate_hf_hub_args
|
3936
|
+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
|
3937
|
+
def lora_state_dict(
|
3938
|
+
cls,
|
3939
|
+
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
3940
|
+
**kwargs,
|
3941
|
+
):
|
3942
|
+
r"""
|
3943
|
+
Return state dict for lora weights and the network alphas.
|
3944
|
+
|
3945
|
+
<Tip warning={true}>
|
3946
|
+
|
3947
|
+
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
3948
|
+
|
3949
|
+
This function is experimental and might change in the future.
|
3950
|
+
|
3951
|
+
</Tip>
|
3952
|
+
|
3953
|
+
Parameters:
|
3954
|
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
3955
|
+
Can be either:
|
3956
|
+
|
3957
|
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
3958
|
+
the Hub.
|
3959
|
+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
3960
|
+
with [`ModelMixin.save_pretrained`].
|
3961
|
+
- A [torch state
|
3962
|
+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
3963
|
+
|
3964
|
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
3965
|
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
3966
|
+
is not used.
|
3967
|
+
force_download (`bool`, *optional*, defaults to `False`):
|
3968
|
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
3969
|
+
cached versions if they exist.
|
3970
|
+
|
3971
|
+
proxies (`Dict[str, str]`, *optional*):
|
3972
|
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
3973
|
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
3974
|
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
3975
|
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
3976
|
+
won't be downloaded from the Hub.
|
3977
|
+
token (`str` or *bool*, *optional*):
|
3978
|
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
3979
|
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
3980
|
+
revision (`str`, *optional*, defaults to `"main"`):
|
3981
|
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
3982
|
+
allowed by Git.
|
3983
|
+
subfolder (`str`, *optional*, defaults to `""`):
|
3984
|
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
3985
|
+
|
3986
|
+
"""
|
3987
|
+
# Load the main state dict first which has the LoRA layers for either of
|
3988
|
+
# transformer and text encoder or both.
|
3989
|
+
cache_dir = kwargs.pop("cache_dir", None)
|
3990
|
+
force_download = kwargs.pop("force_download", False)
|
3991
|
+
proxies = kwargs.pop("proxies", None)
|
3992
|
+
local_files_only = kwargs.pop("local_files_only", None)
|
3993
|
+
token = kwargs.pop("token", None)
|
3994
|
+
revision = kwargs.pop("revision", None)
|
3995
|
+
subfolder = kwargs.pop("subfolder", None)
|
3996
|
+
weight_name = kwargs.pop("weight_name", None)
|
3997
|
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
3998
|
+
|
3999
|
+
allow_pickle = False
|
4000
|
+
if use_safetensors is None:
|
4001
|
+
use_safetensors = True
|
4002
|
+
allow_pickle = True
|
4003
|
+
|
4004
|
+
user_agent = {
|
4005
|
+
"file_type": "attn_procs_weights",
|
4006
|
+
"framework": "pytorch",
|
4007
|
+
}
|
4008
|
+
|
4009
|
+
state_dict = _fetch_state_dict(
|
4010
|
+
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
4011
|
+
weight_name=weight_name,
|
4012
|
+
use_safetensors=use_safetensors,
|
4013
|
+
local_files_only=local_files_only,
|
4014
|
+
cache_dir=cache_dir,
|
4015
|
+
force_download=force_download,
|
4016
|
+
proxies=proxies,
|
4017
|
+
token=token,
|
4018
|
+
revision=revision,
|
4019
|
+
subfolder=subfolder,
|
4020
|
+
user_agent=user_agent,
|
4021
|
+
allow_pickle=allow_pickle,
|
4022
|
+
)
|
4023
|
+
|
4024
|
+
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
4025
|
+
if is_dora_scale_present:
|
4026
|
+
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
4027
|
+
logger.warning(warn_msg)
|
4028
|
+
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
4029
|
+
|
4030
|
+
return state_dict
|
4031
|
+
|
4032
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
4033
|
+
def load_lora_weights(
|
4034
|
+
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
4035
|
+
):
|
4036
|
+
"""
|
4037
|
+
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
4038
|
+
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
|
4039
|
+
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
|
4040
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
|
4041
|
+
dict is loaded into `self.transformer`.
|
4042
|
+
|
4043
|
+
Parameters:
|
4044
|
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
4045
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
4046
|
+
adapter_name (`str`, *optional*):
|
4047
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
4048
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
4049
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
4050
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
4051
|
+
weights.
|
4052
|
+
kwargs (`dict`, *optional*):
|
4053
|
+
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
4054
|
+
"""
|
4055
|
+
if not USE_PEFT_BACKEND:
|
4056
|
+
raise ValueError("PEFT backend is required for this method.")
|
4057
|
+
|
4058
|
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
4059
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
4060
|
+
raise ValueError(
|
4061
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
4062
|
+
)
|
4063
|
+
|
4064
|
+
# if a dict is passed, copy it instead of modifying it inplace
|
4065
|
+
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
4066
|
+
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
4067
|
+
|
4068
|
+
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
4069
|
+
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
4070
|
+
|
4071
|
+
is_correct_format = all("lora" in key for key in state_dict.keys())
|
4072
|
+
if not is_correct_format:
|
4073
|
+
raise ValueError("Invalid LoRA checkpoint.")
|
4074
|
+
|
4075
|
+
self.load_lora_into_transformer(
|
4076
|
+
state_dict,
|
4077
|
+
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
4078
|
+
adapter_name=adapter_name,
|
4079
|
+
_pipeline=self,
|
4080
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
4081
|
+
)
|
4082
|
+
|
4083
|
+
@classmethod
|
4084
|
+
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
|
4085
|
+
def load_lora_into_transformer(
|
4086
|
+
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
4087
|
+
):
|
4088
|
+
"""
|
4089
|
+
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
4090
|
+
|
4091
|
+
Parameters:
|
4092
|
+
state_dict (`dict`):
|
4093
|
+
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
4094
|
+
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
4095
|
+
encoder lora layers.
|
4096
|
+
transformer (`HunyuanVideoTransformer3DModel`):
|
4097
|
+
The Transformer model to load the LoRA layers into.
|
4098
|
+
adapter_name (`str`, *optional*):
|
4099
|
+
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
4100
|
+
`default_{i}` where i is the total number of adapters being loaded.
|
4101
|
+
low_cpu_mem_usage (`bool`, *optional*):
|
4102
|
+
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
4103
|
+
weights.
|
4104
|
+
"""
|
4105
|
+
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
4106
|
+
raise ValueError(
|
4107
|
+
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
4108
|
+
)
|
4109
|
+
|
4110
|
+
# Load the layers corresponding to transformer.
|
4111
|
+
logger.info(f"Loading {cls.transformer_name}.")
|
4112
|
+
transformer.load_lora_adapter(
|
4113
|
+
state_dict,
|
4114
|
+
network_alphas=None,
|
4115
|
+
adapter_name=adapter_name,
|
4116
|
+
_pipeline=_pipeline,
|
4117
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
4118
|
+
)
|
4119
|
+
|
4120
|
+
@classmethod
|
4121
|
+
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
4122
|
+
def save_lora_weights(
|
4123
|
+
cls,
|
4124
|
+
save_directory: Union[str, os.PathLike],
|
4125
|
+
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
4126
|
+
is_main_process: bool = True,
|
4127
|
+
weight_name: str = None,
|
4128
|
+
save_function: Callable = None,
|
4129
|
+
safe_serialization: bool = True,
|
4130
|
+
):
|
4131
|
+
r"""
|
4132
|
+
Save the LoRA parameters corresponding to the UNet and text encoder.
|
4133
|
+
|
4134
|
+
Arguments:
|
4135
|
+
save_directory (`str` or `os.PathLike`):
|
4136
|
+
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
4137
|
+
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
4138
|
+
State dict of the LoRA layers corresponding to the `transformer`.
|
4139
|
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
4140
|
+
Whether the process calling this is the main process or not. Useful during distributed training and you
|
4141
|
+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
4142
|
+
process to avoid race conditions.
|
4143
|
+
save_function (`Callable`):
|
4144
|
+
The function to use to save the state dictionary. Useful during distributed training when you need to
|
4145
|
+
replace `torch.save` with another method. Can be configured with the environment variable
|
4146
|
+
`DIFFUSERS_SAVE_MODE`.
|
4147
|
+
safe_serialization (`bool`, *optional*, defaults to `True`):
|
4148
|
+
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
4149
|
+
"""
|
4150
|
+
state_dict = {}
|
4151
|
+
|
4152
|
+
if not transformer_lora_layers:
|
4153
|
+
raise ValueError("You must pass `transformer_lora_layers`.")
|
4154
|
+
|
4155
|
+
if transformer_lora_layers:
|
4156
|
+
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
4157
|
+
|
4158
|
+
# Save the model
|
4159
|
+
cls.write_lora_layers(
|
4160
|
+
state_dict=state_dict,
|
4161
|
+
save_directory=save_directory,
|
4162
|
+
is_main_process=is_main_process,
|
4163
|
+
weight_name=weight_name,
|
4164
|
+
save_function=save_function,
|
4165
|
+
safe_serialization=safe_serialization,
|
4166
|
+
)
|
4167
|
+
|
4168
|
+
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
|
4169
|
+
def fuse_lora(
|
4170
|
+
self,
|
4171
|
+
components: List[str] = ["transformer", "text_encoder"],
|
4172
|
+
lora_scale: float = 1.0,
|
4173
|
+
safe_fusing: bool = False,
|
4174
|
+
adapter_names: Optional[List[str]] = None,
|
4175
|
+
**kwargs,
|
4176
|
+
):
|
4177
|
+
r"""
|
4178
|
+
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
4179
|
+
|
4180
|
+
<Tip warning={true}>
|
4181
|
+
|
4182
|
+
This is an experimental API.
|
4183
|
+
|
4184
|
+
</Tip>
|
4185
|
+
|
4186
|
+
Args:
|
4187
|
+
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
4188
|
+
lora_scale (`float`, defaults to 1.0):
|
4189
|
+
Controls how much to influence the outputs with the LoRA parameters.
|
4190
|
+
safe_fusing (`bool`, defaults to `False`):
|
4191
|
+
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
4192
|
+
adapter_names (`List[str]`, *optional*):
|
4193
|
+
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
4194
|
+
|
4195
|
+
Example:
|
4196
|
+
|
4197
|
+
```py
|
4198
|
+
from diffusers import DiffusionPipeline
|
4199
|
+
import torch
|
4200
|
+
|
4201
|
+
pipeline = DiffusionPipeline.from_pretrained(
|
4202
|
+
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
4203
|
+
).to("cuda")
|
4204
|
+
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
4205
|
+
pipeline.fuse_lora(lora_scale=0.7)
|
4206
|
+
```
|
4207
|
+
"""
|
4208
|
+
super().fuse_lora(
|
4209
|
+
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
4210
|
+
)
|
4211
|
+
|
4212
|
+
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
|
4213
|
+
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
4214
|
+
r"""
|
4215
|
+
Reverses the effect of
|
4216
|
+
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
4217
|
+
|
4218
|
+
<Tip warning={true}>
|
4219
|
+
|
4220
|
+
This is an experimental API.
|
4221
|
+
|
4222
|
+
</Tip>
|
4223
|
+
|
4224
|
+
Args:
|
4225
|
+
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
4226
|
+
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
4227
|
+
unfuse_text_encoder (`bool`, defaults to `True`):
|
4228
|
+
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
4229
|
+
LoRA parameters then it won't have any effect.
|
4230
|
+
"""
|
4231
|
+
super().unfuse_lora(components=components)
|
4232
|
+
|
4233
|
+
|
2248
4234
|
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
2249
4235
|
def __init__(self, *args, **kwargs):
|
2250
4236
|
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
|