diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +20 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +23 -25
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +46 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
- diffusers/schedulers/scheduling_edm_euler.py +53 -30
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
- diffusers/schedulers/scheduling_euler_discrete.py +163 -67
- diffusers/schedulers/scheduling_heun_discrete.py +60 -38
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +27 -25
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
- diffusers-0.28.0.dist-info/RECORD +414 -0
- diffusers-0.27.1.dist-info/RECORD +0 -399
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/loaders/lora.py
CHANGED
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import inspect
 import os
 from pathlib import Path
@@ -25,7 +26,7 @@ from packaging import version
 from torch import nn

 from .. import __version__
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
@@ -36,6 +37,7 @@ from ..utils import (
     get_adapter_name,
     get_peft_kwargs,
     is_accelerate_available,
+    is_peft_version,
     is_transformers_available,
     logging,
     recurse_remove_peft_layers,
@@ -113,7 +115,7 @@ class LoraLoaderMixin:
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

-        is_correct_format = all("lora" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

@@ -174,9 +176,9 @@ class LoraLoaderMixin:
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download
-
-
+            resume_download:
+                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
+                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -206,7 +208,7 @@ class LoraLoaderMixin:
         # UNet and text encoder or both.
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download",
+        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
@@ -281,7 +283,7 @@ class LoraLoaderMixin:
                     subfolder=subfolder,
                     user_agent=user_agent,
                 )
-                state_dict =
+                state_dict = load_state_dict(model_file)
         else:
             state_dict = pretrained_model_name_or_path_or_dict

@@ -361,13 +363,17 @@ class LoraLoaderMixin:
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False

-        if _pipeline is not None:
+        if _pipeline is not None and _pipeline.hf_device_map is None:
             for _, component in _pipeline.components.items():
                 if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                     if not is_model_cpu_offload:
                         is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
                     if not is_sequential_cpu_offload:
-                        is_sequential_cpu_offload =
+                        is_sequential_cpu_offload = (
+                            isinstance(component._hf_hook, AlignDevicesHook)
+                            or hasattr(component._hf_hook, "hooks")
+                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                        )

                     logger.info(
                         "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
@@ -451,6 +457,15 @@ class LoraLoaderMixin:
                     rank[key] = val.shape[1]

             lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
+            if "use_dora" in lora_config_kwargs:
+                if lora_config_kwargs["use_dora"]:
+                    if is_peft_version("<", "0.9.0"):
+                        raise ValueError(
+                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<", "0.9.0"):
+                        lora_config_kwargs.pop("use_dora")
             lora_config = LoraConfig(**lora_config_kwargs)

             # adapter_name
@@ -572,6 +587,15 @@ class LoraLoaderMixin:
             }

             lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+            if "use_dora" in lora_config_kwargs:
+                if lora_config_kwargs["use_dora"]:
+                    if is_peft_version("<", "0.9.0"):
+                        raise ValueError(
+                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<", "0.9.0"):
+                        lora_config_kwargs.pop("use_dora")
             lora_config = LoraConfig(**lora_config_kwargs)

             # adapter_name
@@ -654,6 +678,13 @@ class LoraLoaderMixin:
                 rank[key] = val.shape[1]

         lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
+        if "use_dora" in lora_config_kwargs:
+            if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
+                raise ValueError(
+                    "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                )
+            else:
+                lora_config_kwargs.pop("use_dora")
         lora_config = LoraConfig(**lora_config_kwargs)

         # adapter_name
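The three hunks above gate DoRA-style checkpoints (which carry `dora_scale` / `lora_magnitude_vector` entries) on `peft` >= 0.9.0. A minimal usage sketch, assuming that requirement is met; the repository and weight-file names are hypothetical placeholders:

    # Loading a DoRA-enabled LoRA goes through the same entry point as a plain LoRA;
    # the version check above raises if the installed peft is older than 0.9.0.
    import torch
    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")
    pipe.load_lora_weights("some-user/sdxl-dora-lora", weight_name="pytorch_lora_weights.safetensors")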
@@ -959,7 +990,7 @@ class LoraLoaderMixin:
         self,
         adapter_names: Union[List[str], str],
         text_encoder: Optional["PreTrainedModel"] = None,  # noqa: F821
-        text_encoder_weights: List[float] = None,
+        text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
     ):
         """
         Sets the adapter layers for the text encoder.
@@ -977,15 +1008,20 @@ class LoraLoaderMixin:
             raise ValueError("PEFT backend is required for this method.")

         def process_weights(adapter_names, weights):
-
-
-
-                weights = [weights]
+            # Expand weights into a list, one entry per adapter
+            # e.g. for 2 adapters:  7 -> [7,7] ; [3, None] -> [3, None]
+            if not isinstance(weights, list):
+                weights = [weights] * len(adapter_names)

             if len(adapter_names) != len(weights):
                 raise ValueError(
                     f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
                 )
+
+            # Set None values to default of 1.0
+            # e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
+            weights = [w if w is not None else 1.0 for w in weights]
+
             return weights

         adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
@@ -1033,17 +1069,77 @@ class LoraLoaderMixin:
     def set_adapters(
         self,
         adapter_names: Union[List[str], str],
-        adapter_weights: Optional[List[float]] = None,
+        adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
     ):
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+
+        adapter_weights = copy.deepcopy(adapter_weights)
+
+        # Expand weights into a list, one entry per adapter
+        if not isinstance(adapter_weights, list):
+            adapter_weights = [adapter_weights] * len(adapter_names)
+
+        if len(adapter_names) != len(adapter_weights):
+            raise ValueError(
+                f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
+            )
+
+        # Decompose weights into weights for unet, text_encoder and text_encoder_2
+        unet_lora_weights, text_encoder_lora_weights, text_encoder_2_lora_weights = [], [], []
+
+        list_adapters = self.get_list_adapters()  # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
+        all_adapters = {
+            adapter for adapters in list_adapters.values() for adapter in adapters
+        }  # eg ["adapter1", "adapter2"]
+        invert_list_adapters = {
+            adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
+            for adapter in all_adapters
+        }  # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
+
+        for adapter_name, weights in zip(adapter_names, adapter_weights):
+            if isinstance(weights, dict):
+                unet_lora_weight = weights.pop("unet", None)
+                text_encoder_lora_weight = weights.pop("text_encoder", None)
+                text_encoder_2_lora_weight = weights.pop("text_encoder_2", None)
+
+                if len(weights) > 0:
+                    raise ValueError(
+                        f"Got invalid key '{weights.keys()}' in lora weight dict for adapter {adapter_name}."
+                    )
+
+                if text_encoder_2_lora_weight is not None and not hasattr(self, "text_encoder_2"):
+                    logger.warning(
+                        "Lora weight dict contains text_encoder_2 weights but will be ignored because pipeline does not have text_encoder_2."
+                    )
+
+                # warn if adapter doesn't have parts specified by adapter_weights
+                for part_weight, part_name in zip(
+                    [unet_lora_weight, text_encoder_lora_weight, text_encoder_2_lora_weight],
+                    ["unet", "text_encoder", "text_encoder_2"],
+                ):
+                    if part_weight is not None and part_name not in invert_list_adapters[adapter_name]:
+                        logger.warning(
+                            f"Lora weight dict for adapter '{adapter_name}' contains {part_name}, but this will be ignored because {adapter_name} does not contain weights for {part_name}. Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
+                        )
+
+            else:
+                unet_lora_weight = weights
+                text_encoder_lora_weight = weights
+                text_encoder_2_lora_weight = weights
+
+            unet_lora_weights.append(unet_lora_weight)
+            text_encoder_lora_weights.append(text_encoder_lora_weight)
+            text_encoder_2_lora_weights.append(text_encoder_2_lora_weight)
+
         unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
         # Handle the UNET
-        unet.set_adapters(adapter_names,
+        unet.set_adapters(adapter_names, unet_lora_weights)

         # Handle the Text Encoder
         if hasattr(self, "text_encoder"):
-            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder,
+            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, text_encoder_lora_weights)
         if hasattr(self, "text_encoder_2"):
-            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2,
+            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, text_encoder_2_lora_weights)

     def disable_lora(self):
         if not USE_PEFT_BACKEND:
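The widened `adapter_weights` signature above accepts either a single float for all components or a per-component dict with `unet`, `text_encoder`, and `text_encoder_2` keys. A small hedged sketch, assuming two adapters (the names "pixel" and "toy" are hypothetical) were loaded earlier via `load_lora_weights(..., adapter_name=...)`:

    # Per-component LoRA scales via the new dict form of adapter_weights.
    pipe.set_adapters(
        ["pixel", "toy"],
        adapter_weights=[
            {"unet": 1.0, "text_encoder": 0.5},  # dict: different scale per sub-model
            0.8,                                 # plain float: same scale everywhere
        ],
    )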
@@ -1175,6 +1271,11 @@ class LoraLoaderMixin:
                 for adapter_name in adapter_names:
                     unet_module.lora_A[adapter_name].to(device)
                     unet_module.lora_B[adapter_name].to(device)
+                    # this is a param, not a module, so device placement is not in-place -> re-assign
+                    if hasattr(unet_module, "lora_magnitude_vector") and unet_module.lora_magnitude_vector is not None:
+                        unet_module.lora_magnitude_vector[adapter_name] = unet_module.lora_magnitude_vector[
+                            adapter_name
+                        ].to(device)

         # Handle the text encoder
         modules_to_process = []
@@ -1191,6 +1292,14 @@ class LoraLoaderMixin:
                     for adapter_name in adapter_names:
                         text_encoder_module.lora_A[adapter_name].to(device)
                         text_encoder_module.lora_B[adapter_name].to(device)
+                        # this is a param, not a module, so device placement is not in-place -> re-assign
+                        if (
+                            hasattr(text_encoder_module, "lora_magnitude_vector")
+                            and text_encoder_module.lora_magnitude_vector is not None
+                        ):
+                            text_encoder_module.lora_magnitude_vector[
+                                adapter_name
+                            ] = text_encoder_module.lora_magnitude_vector[adapter_name].to(device)


 class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
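These two device-placement hunks sit in the pipeline-level helper that moves LoRA layers between devices (the hunk headers only show the class, but this appears to be `set_lora_device`); DoRA magnitude vectors are parameters rather than modules, so the `.to(device)` result must be re-assigned. A brief hedged usage sketch:

    # Move a specific adapter's LoRA (and DoRA magnitude) weights to a device;
    # assumes the surrounding method is set_lora_device, which the hunk headers do not show.
    pipe.set_lora_device(adapter_names=["pixel"], device="cuda:0")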
@@ -1243,7 +1352,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
             unet_config=self.unet.config,
             **kwargs,
         )
-        is_correct_format = all("lora" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

@@ -1297,6 +1406,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
             text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                 State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                 encoder LoRA state dict because it comes from 🤗 Transformers.
+            text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
+                encoder LoRA state dict because it comes from 🤗 Transformers.
             is_main_process (`bool`, *optional*, defaults to `True`):
                 Whether the process calling this is the main process or not. Useful during distributed training and you
                 need to call this function on all processes. In this case, set `is_main_process=True` only on the main
@@ -1323,8 +1435,10 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
         if unet_lora_layers:
             state_dict.update(pack_weights(unet_lora_layers, "unet"))

-        if text_encoder_lora_layers
+        if text_encoder_lora_layers:
             state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+
+        if text_encoder_2_lora_layers:
             state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

         cls.write_lora_layers(
diffusers/loaders/lora_conversion_utils.py
CHANGED
@@ -14,7 +14,7 @@

 import re

-from ..utils import logging
+from ..utils import is_peft_version, logging


 logger = logging.get_logger(__name__)
@@ -128,6 +128,15 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
     te_state_dict = {}
     te2_state_dict = {}
     network_alphas = {}
+    is_unet_dora_lora = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
+    is_te_dora_lora = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
+    is_te2_dora_lora = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
+
+    if is_unet_dora_lora or is_te_dora_lora or is_te2_dora_lora:
+        if is_peft_version("<", "0.9.0"):
+            raise ValueError(
+                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+            )

     # every down weight has a corresponding up weight and potentially an alpha weight
     lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
@@ -198,46 +207,19 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
             unet_state_dict[diffusers_name] = state_dict.pop(key)
             unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)

-
-
-
-
-
-            diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
-            diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
-            diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
-            if "self_attn" in diffusers_name:
-                te_state_dict[diffusers_name] = state_dict.pop(key)
-                te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            elif "mlp" in diffusers_name:
-                # Be aware that this is the new diffusers convention and the rest of the code might
-                # not utilize it yet.
-                diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
-                te_state_dict[diffusers_name] = state_dict.pop(key)
-                te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+            if is_unet_dora_lora:
+                dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
+                unet_state_dict[
+                    diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
+                ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))

-
-
-
-
-
-            diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
-            diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
-            diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
-            diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
-            if "self_attn" in diffusers_name:
-                te_state_dict[diffusers_name] = state_dict.pop(key)
-                te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            elif "mlp" in diffusers_name:
-                # Be aware that this is the new diffusers convention and the rest of the code might
-                # not utilize it yet.
-                diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
-                te_state_dict[diffusers_name] = state_dict.pop(key)
-                te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+        elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
+            if lora_name.startswith(("lora_te_", "lora_te1_")):
+                key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
+            else:
+                key_to_replace = "lora_te2_"

-
-        elif lora_name.startswith("lora_te2_"):
-            diffusers_name = key.replace("lora_te2_", "").replace("_", ".")
+            diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
             diffusers_name = diffusers_name.replace("text.model", "text_model")
             diffusers_name = diffusers_name.replace("self.attn", "self_attn")
             diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
@@ -245,14 +227,35 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
             diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
             diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
             if "self_attn" in diffusers_name:
-
-
+                if lora_name.startswith(("lora_te_", "lora_te1_")):
+                    te_state_dict[diffusers_name] = state_dict.pop(key)
+                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+                else:
+                    te2_state_dict[diffusers_name] = state_dict.pop(key)
+                    te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
             elif "mlp" in diffusers_name:
                 # Be aware that this is the new diffusers convention and the rest of the code might
                 # not utilize it yet.
                 diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
-
-
+                if lora_name.startswith(("lora_te_", "lora_te1_")):
+                    te_state_dict[diffusers_name] = state_dict.pop(key)
+                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+                else:
+                    te2_state_dict[diffusers_name] = state_dict.pop(key)
+                    te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+
+            if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
+                dora_scale_key_to_replace_te = (
+                    "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
+                )
+                if lora_name.startswith(("lora_te_", "lora_te1_")):
+                    te_state_dict[
+                        diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
+                    ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+                elif lora_name.startswith("lora_te2_"):
+                    te2_state_dict[
+                        diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
+                    ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))

         # Rename the alphas so that they can be mapped appropriately.
         if lora_name_alpha in state_dict:
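The converter shown in these hunk headers is `_convert_kohya_lora_to_diffusers`, which `load_lora_weights` calls internally for Kohya-format files; the new branches additionally move each `dora_scale` tensor under a `.lora_magnitude_vector.` key. A hedged sketch of invoking the (private) helper directly, assuming safetensors is installed and using a placeholder file path:

    # Convert a Kohya-format (optionally DoRA) LoRA state dict by hand.
    import safetensors.torch
    from diffusers.loaders.lora_conversion_utils import _convert_kohya_lora_to_diffusers

    kohya_sd = safetensors.torch.load_file("my_kohya_lora.safetensors")  # placeholder path
    diffusers_sd, network_alphas = _convert_kohya_lora_to_diffusers(kohya_sd)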
diffusers/loaders/peft.py
CHANGED
@@ -20,7 +20,8 @@ from ..utils import MIN_PEFT_VERSION, check_peft_version, is_peft_available
 class PeftAdapterMixin:
     """
     A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
-    more details about adapters and injecting them in a transformer-based model, check out the PEFT
+    more details about adapters and injecting them in a transformer-based model, check out the PEFT
+    [documentation](https://huggingface.co/docs/peft/index).

     Install the latest version of PEFT, and use this mixin to:

@@ -143,8 +144,8 @@ class PeftAdapterMixin:

     def enable_adapters(self) -> None:
         """
-        Enable adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the
-
+        Enable adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the list of
+        adapters to enable.

         If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
         [documentation](https://huggingface.co/docs/peft).