diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +19 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +20 -26
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +42 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
- diffusers/schedulers/scheduling_edm_euler.py +50 -31
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
- diffusers/schedulers/scheduling_euler_discrete.py +160 -68
- diffusers/schedulers/scheduling_heun_discrete.py +57 -39
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +24 -26
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
- diffusers-0.28.0.dist-info/RECORD +414 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
- diffusers-0.27.2.dist-info/RECORD +0 -399
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/loaders/textual_inversion.py CHANGED
@@ -18,6 +18,7 @@ import torch
 from huggingface_hub.utils import validate_hf_hub_args
 from torch import nn
 
+from ..models.modeling_utils import load_state_dict
 from ..utils import _get_model_file, is_accelerate_available, is_transformers_available, logging
 
 
@@ -37,7 +38,7 @@ TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
 def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
     cache_dir = kwargs.pop("cache_dir", None)
     force_download = kwargs.pop("force_download", False)
-    resume_download = kwargs.pop("resume_download", False)
+    resume_download = kwargs.pop("resume_download", None)
     proxies = kwargs.pop("proxies", None)
     local_files_only = kwargs.pop("local_files_only", None)
     token = kwargs.pop("token", None)
@@ -100,7 +101,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
                     subfolder=subfolder,
                     user_agent=user_agent,
                 )
-                state_dict = torch.load(model_file, map_location="cpu")
+                state_dict = load_state_dict(model_file)
         else:
             state_dict = pretrained_model_name_or_path
 
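The two hunks above drop the ad-hoc `torch.load` call in favor of `load_state_dict` from `diffusers.models.modeling_utils`, which resolves the checkpoint format itself. A minimal sketch of that unified entry point, assuming diffusers 0.28.0 and a hypothetical local embedding file:

```python
from diffusers.models.modeling_utils import load_state_dict

# "learned_embeds.safetensors" is a placeholder path; the same call also accepts
# a pickled .bin checkpoint, which is why the caller no longer branches on format.
state_dict = load_state_dict("learned_embeds.safetensors")
print(sorted(state_dict.keys()))
```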
@@ -307,9 +308,9 @@ class TextualInversionLoaderMixin:
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
-                incompletely downloaded files are deleted.
+            resume_download:
+                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
+                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -418,15 +419,20 @@ class TextualInversionLoaderMixin:
         # 7.1 Offload all hooks in case the pipeline was cpu offloaded before make sure, we offload and onload again
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
-        [9 removed lines not rendered in the source diff view]
+        if self.hf_device_map is None:
+            for _, component in self.components.items():
+                if isinstance(component, nn.Module):
+                    if hasattr(component, "_hf_hook"):
+                        is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                        is_sequential_cpu_offload = (
+                            isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                            or hasattr(component._hf_hook, "hooks")
+                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                        )
+                        logger.info(
+                            "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
+                        )
+                        remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
 
         # 7.2 save expected device and dtype
         device = text_encoder.device
@@ -486,20 +492,35 @@ class TextualInversionLoaderMixin:
 
         # Example 3: unload from SDXL
         pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
-        embedding_path = hf_hub_download(repo_id="linoyts/web_y2k", filename="web_y2k_emb.safetensors", repo_type="model")
+        embedding_path = hf_hub_download(
+            repo_id="linoyts/web_y2k", filename="web_y2k_emb.safetensors", repo_type="model"
+        )
 
         # load embeddings to the text encoders
         state_dict = load_file(embedding_path)
 
         # load embeddings of text_encoder 1 (CLIP ViT-L/14)
-        pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
+        pipeline.load_textual_inversion(
+            state_dict["clip_l"],
+            token=["<s0>", "<s1>"],
+            text_encoder=pipeline.text_encoder,
+            tokenizer=pipeline.tokenizer,
+        )
         # load embeddings of text_encoder 2 (CLIP ViT-G/14)
-        pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
+        pipeline.load_textual_inversion(
+            state_dict["clip_g"],
+            token=["<s0>", "<s1>"],
+            text_encoder=pipeline.text_encoder_2,
+            tokenizer=pipeline.tokenizer_2,
+        )
 
         # Unload explicitly from both text encoders abd tokenizers
-        pipeline.unload_textual_inversion(
-            tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer
-        )
+        pipeline.unload_textual_inversion(
+            tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer
+        )
+        pipeline.unload_textual_inversion(
+            tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2
+        )
         ```
         """
 
diffusers/loaders/unet.py CHANGED
@@ -27,11 +27,13 @@ from torch import nn
 
 from ..models.embeddings import (
     ImageProjection,
+    IPAdapterFaceIDImageProjection,
+    IPAdapterFaceIDPlusImageProjection,
     IPAdapterFullImageProjection,
     IPAdapterPlusImageProjection,
     MultiIPAdapterImageProjection,
 )
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
@@ -42,11 +44,7 @@ from ..utils import (
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
-from .single_file_utils import (
-    convert_stable_cascade_unet_single_file_to_diffusers,
-    infer_stable_cascade_single_file_config,
-    load_single_file_model_checkpoint,
-)
+from .unet_loader_utils import _maybe_expand_lora_scales
 from .utils import AttnProcsLayers
 
 
@@ -100,9 +98,9 @@ class UNet2DConditionLoadersMixin:
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
-                incompletely downloaded files are deleted.
+            resume_download:
+                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
+                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -146,7 +144,7 @@ class UNet2DConditionLoadersMixin:
 
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", False)
+        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
@@ -214,7 +212,7 @@ class UNet2DConditionLoadersMixin:
                     subfolder=subfolder,
                     user_agent=user_agent,
                 )
-                state_dict = torch.load(model_file, map_location="cpu")
+                state_dict = load_state_dict(model_file)
         else:
             state_dict = pretrained_model_name_or_path_or_dict
 
@@ -356,7 +354,11 @@ class UNet2DConditionLoadersMixin:
             for _, component in _pipeline.components.items():
                 if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                     is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                    is_sequential_cpu_offload = (
+                        isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                        or hasattr(component._hf_hook, "hooks")
+                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                    )
 
                     logger.info(
                         "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
@@ -564,7 +566,7 @@ class UNet2DConditionLoadersMixin:
     def set_adapters(
         self,
         adapter_names: Union[List[str], str],
-        weights: Optional[Union[List[float], float]] = None,
+        weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
     ):
         """
         Set the currently active adapters for use in the UNet.
@@ -597,9 +599,9 @@ class UNet2DConditionLoadersMixin:
 
         adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
 
-        [3 removed lines not rendered in the source diff view]
+        # Expand weights into a list, one entry per adapter
+        # examples for e.g. 2 adapters: [{...}, 7] -> [7,7] ; None -> [None, None]
+        if not isinstance(weights, list):
             weights = [weights] * len(adapter_names)
 
         if len(adapter_names) != len(weights):
@@ -607,6 +609,13 @@ class UNet2DConditionLoadersMixin:
                 f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
             )
 
+        # Set None values to default of 1.0
+        # e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0]
+        weights = [w if w is not None else 1.0 for w in weights]
+
+        # e.g. [{...}, 7] -> [{expanded dict...}, 7]
+        weights = _maybe_expand_lora_scales(self, weights)
+
         set_weights_and_activate_adapters(self, adapter_names, weights)
 
     def disable_lora(self):
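The `set_adapters` changes above mean `weights` may now be a float, a nested dict, or a list mixing both, with `None` entries falling back to 1.0 and dicts expanded per block by `_maybe_expand_lora_scales`. A hedged usage sketch (the LoRA repository and adapter name are hypothetical; assumes diffusers 0.28.0 with the PEFT backend installed):

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.load_lora_weights("some-user/some-sdxl-lora", adapter_name="style")  # hypothetical LoRA repo

# A plain float still works as before ...
pipe.unet.set_adapters(["style"], [0.8])

# ... and a nested dict is now expanded block by block. Per the helper above,
# SDXL lists need layers_per_block (2) entries for a down block and 3 for an up block.
pipe.unet.set_adapters(
    ["style"],
    [{"down": 0.6, "mid": 1.0, "up": {"block_0": 0.9, "block_1": [0.3, 0.6, 0.9]}}],
)
```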
@@ -748,6 +757,90 @@ class UNet2DConditionLoadersMixin:
                 diffusers_name = diffusers_name.replace("proj.3", "norm")
                 updated_state_dict[diffusers_name] = value
 
+        elif "perceiver_resampler.proj_in.weight" in state_dict:
+            # IP-Adapter Face ID Plus
+            id_embeddings_dim = state_dict["proj.0.weight"].shape[1]
+            embed_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[0]
+            hidden_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[1]
+            output_dims = state_dict["perceiver_resampler.proj_out.weight"].shape[0]
+            heads = state_dict["perceiver_resampler.layers.0.0.to_q.weight"].shape[0] // 64
+
+            with init_context():
+                image_projection = IPAdapterFaceIDPlusImageProjection(
+                    embed_dims=embed_dims,
+                    output_dims=output_dims,
+                    hidden_dims=hidden_dims,
+                    heads=heads,
+                    id_embeddings_dim=id_embeddings_dim,
+                )
+
+            for key, value in state_dict.items():
+                diffusers_name = key.replace("perceiver_resampler.", "")
+                diffusers_name = diffusers_name.replace("0.to", "attn.to")
+                diffusers_name = diffusers_name.replace("0.1.0.", "0.ff.0.")
+                diffusers_name = diffusers_name.replace("0.1.1.weight", "0.ff.1.net.0.proj.weight")
+                diffusers_name = diffusers_name.replace("0.1.3.weight", "0.ff.1.net.2.weight")
+                diffusers_name = diffusers_name.replace("1.1.0.", "1.ff.0.")
+                diffusers_name = diffusers_name.replace("1.1.1.weight", "1.ff.1.net.0.proj.weight")
+                diffusers_name = diffusers_name.replace("1.1.3.weight", "1.ff.1.net.2.weight")
+                diffusers_name = diffusers_name.replace("2.1.0.", "2.ff.0.")
+                diffusers_name = diffusers_name.replace("2.1.1.weight", "2.ff.1.net.0.proj.weight")
+                diffusers_name = diffusers_name.replace("2.1.3.weight", "2.ff.1.net.2.weight")
+                diffusers_name = diffusers_name.replace("3.1.0.", "3.ff.0.")
+                diffusers_name = diffusers_name.replace("3.1.1.weight", "3.ff.1.net.0.proj.weight")
+                diffusers_name = diffusers_name.replace("3.1.3.weight", "3.ff.1.net.2.weight")
+                diffusers_name = diffusers_name.replace("layers.0.0", "layers.0.ln0")
+                diffusers_name = diffusers_name.replace("layers.0.1", "layers.0.ln1")
+                diffusers_name = diffusers_name.replace("layers.1.0", "layers.1.ln0")
+                diffusers_name = diffusers_name.replace("layers.1.1", "layers.1.ln1")
+                diffusers_name = diffusers_name.replace("layers.2.0", "layers.2.ln0")
+                diffusers_name = diffusers_name.replace("layers.2.1", "layers.2.ln1")
+                diffusers_name = diffusers_name.replace("layers.3.0", "layers.3.ln0")
+                diffusers_name = diffusers_name.replace("layers.3.1", "layers.3.ln1")
+
+                if "norm1" in diffusers_name:
+                    updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
+                elif "norm2" in diffusers_name:
+                    updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
+                elif "to_kv" in diffusers_name:
+                    v_chunk = value.chunk(2, dim=0)
+                    updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
+                    updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
+                elif "to_out" in diffusers_name:
+                    updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
+                elif "proj.0.weight" == diffusers_name:
+                    updated_state_dict["proj.net.0.proj.weight"] = value
+                elif "proj.0.bias" == diffusers_name:
+                    updated_state_dict["proj.net.0.proj.bias"] = value
+                elif "proj.2.weight" == diffusers_name:
+                    updated_state_dict["proj.net.2.weight"] = value
+                elif "proj.2.bias" == diffusers_name:
+                    updated_state_dict["proj.net.2.bias"] = value
+                else:
+                    updated_state_dict[diffusers_name] = value
+
+        elif "norm.weight" in state_dict:
+            # IP-Adapter Face ID
+            id_embeddings_dim_in = state_dict["proj.0.weight"].shape[1]
+            id_embeddings_dim_out = state_dict["proj.0.weight"].shape[0]
+            multiplier = id_embeddings_dim_out // id_embeddings_dim_in
+            norm_layer = "norm.weight"
+            cross_attention_dim = state_dict[norm_layer].shape[0]
+            num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim
+
+            with init_context():
+                image_projection = IPAdapterFaceIDImageProjection(
+                    cross_attention_dim=cross_attention_dim,
+                    image_embed_dim=id_embeddings_dim_in,
+                    mult=multiplier,
+                    num_tokens=num_tokens,
+                )
+
+            for key, value in state_dict.items():
+                diffusers_name = key.replace("proj.0", "ff.net.0.proj")
+                diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
+                updated_state_dict[diffusers_name] = value
+
         else:
             # IP-Adapter Plus
             num_image_text_embeds = state_dict["latents"].shape[1]
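The new conversion branches above let `load_ip_adapter` accept IP-Adapter Face ID and Face ID Plus checkpoints. A hedged sketch of loading one (the Hub repository and weight-file names refer to the community FaceID release and are an assumption, not part of this diff; assumes diffusers 0.28.0):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.load_ip_adapter(
    "h94/IP-Adapter-FaceID",                   # assumed community repo
    subfolder=None,
    weight_name="ip-adapter-faceid_sd15.bin",  # assumed file name
    image_encoder_folder=None,                 # Face ID checkpoints ship no CLIP image encoder
)
pipe.set_ip_adapter_scale(0.6)
# At inference time Face ID expects precomputed face-identity embeddings
# (e.g. from insightface) passed via `ip_adapter_image_embeds` rather than a raw image.
```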
@@ -839,6 +932,7 @@ class UNet2DConditionLoadersMixin:
                 AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
             )
             attn_procs[name] = attn_processor_class()
+
         else:
             attn_processor_class = (
                 IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
@@ -851,6 +945,12 @@ class UNet2DConditionLoadersMixin:
             elif "proj.3.weight" in state_dict["image_proj"]:
                 # IP-Adapter Full Face
                 num_image_text_embeds += [257]  # 256 CLIP tokens + 1 CLS token
+            elif "perceiver_resampler.proj_in.weight" in state_dict["image_proj"]:
+                # IP-Adapter Face ID Plus
+                num_image_text_embeds += [4]
+            elif "norm.weight" in state_dict["image_proj"]:
+                # IP-Adapter Face ID
+                num_image_text_embeds += [4]
             else:
                 # IP-Adapter Plus
                 num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
@@ -902,102 +1002,55 @@ class UNet2DConditionLoadersMixin:
 
         self.to(dtype=self.dtype, device=self.device)
 
-        [53 removed lines not rendered in the source diff view]
-        config = kwargs.pop("config", None)
-        resume_download = kwargs.pop("resume_download", False)
-        force_download = kwargs.pop("force_download", False)
-        proxies = kwargs.pop("proxies", None)
-        token = kwargs.pop("token", None)
-        cache_dir = kwargs.pop("cache_dir", None)
-        local_files_only = kwargs.pop("local_files_only", None)
-        revision = kwargs.pop("revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", None)
-
-        checkpoint = load_single_file_model_checkpoint(
-            pretrained_model_link_or_path,
-            resume_download=resume_download,
-            force_download=force_download,
-            proxies=proxies,
-            token=token,
-            cache_dir=cache_dir,
-            local_files_only=local_files_only,
-            revision=revision,
-        )
-
-        if config is None:
-            config = infer_stable_cascade_single_file_config(checkpoint)
-            model_config = cls.load_config(**config, **kwargs)
-        else:
-            model_config = config
-
-        ctx = init_empty_weights if is_accelerate_available() else nullcontext
-        with ctx():
-            model = cls.from_config(model_config, **kwargs)
-
-        diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint)
-        if is_accelerate_available():
-            unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
-            if len(unexpected_keys) > 0:
-                logger.warn(
-                    f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
-                )
-
-        else:
-            model.load_state_dict(diffusers_format_checkpoint)
-
-        if torch_dtype is not None:
-            model.to(torch_dtype)
-
-        return model
+    def _load_ip_adapter_loras(self, state_dicts):
+        lora_dicts = {}
+        for key_id, name in enumerate(self.attn_processors.keys()):
+            for i, state_dict in enumerate(state_dicts):
+                if f"{key_id}.to_k_lora.down.weight" in state_dict["ip_adapter"]:
+                    if i not in lora_dicts:
+                        lora_dicts[i] = {}
+                    lora_dicts[i].update(
+                        {
+                            f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][
+                                f"{key_id}.to_k_lora.down.weight"
+                            ]
+                        }
+                    )
+                    lora_dicts[i].update(
+                        {
+                            f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][
+                                f"{key_id}.to_q_lora.down.weight"
+                            ]
+                        }
+                    )
+                    lora_dicts[i].update(
+                        {
+                            f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][
+                                f"{key_id}.to_v_lora.down.weight"
+                            ]
+                        }
+                    )
+                    lora_dicts[i].update(
+                        {
+                            f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][
+                                f"{key_id}.to_out_lora.down.weight"
+                            ]
+                        }
+                    )
+                    lora_dicts[i].update(
+                        {f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}
+                    )
+                    lora_dicts[i].update(
+                        {f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}
+                    )
+                    lora_dicts[i].update(
+                        {f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}
+                    )
+                    lora_dicts[i].update(
+                        {
+                            f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][
+                                f"{key_id}.to_out_lora.up.weight"
+                            ]
+                        }
+                    )
+        return lora_dicts
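For reference, the renaming performed by `_load_ip_adapter_loras` above is purely mechanical: for each attention-processor index it moves LoRA tensors from the checkpoint's `ip_adapter` dict onto UNet-scoped keys. An illustrative, self-contained sketch (the processor name below is hypothetical):

```python
key_id = 0
name = "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor"  # hypothetical processor name

# The method maps the same eight projections per processor: q/k/v/out, down and up weights.
for proj in ("to_k", "to_q", "to_v", "to_out"):
    for direction in ("down", "up"):
        checkpoint_key = f"{key_id}.{proj}_lora.{direction}.weight"
        diffusers_key = f"unet.{name}.{proj}_lora.{direction}.weight"
        print(checkpoint_key, "->", diffusers_key)
```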
diffusers/loaders/unet_loader_utils.py ADDED
@@ -0,0 +1,163 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+from typing import TYPE_CHECKING, Dict, List, Union
+
+from ..utils import logging
+
+
+if TYPE_CHECKING:
+    # import here to avoid circular imports
+    from ..models import UNet2DConditionModel
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+def _translate_into_actual_layer_name(name):
+    """Translate user-friendly name (e.g. 'mid') into actual layer name (e.g. 'mid_block.attentions.0')"""
+    if name == "mid":
+        return "mid_block.attentions.0"
+
+    updown, block, attn = name.split(".")
+
+    updown = updown.replace("down", "down_blocks").replace("up", "up_blocks")
+    block = block.replace("block_", "")
+    attn = "attentions." + attn
+
+    return ".".join((updown, block, attn))
+
+
+def _maybe_expand_lora_scales(
+    unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]], default_scale=1.0
+):
+    blocks_with_transformer = {
+        "down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
+        "up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
+    }
+    transformer_per_block = {"down": unet.config.layers_per_block, "up": unet.config.layers_per_block + 1}
+
+    expanded_weight_scales = [
+        _maybe_expand_lora_scales_for_one_adapter(
+            weight_for_adapter,
+            blocks_with_transformer,
+            transformer_per_block,
+            unet.state_dict(),
+            default_scale=default_scale,
+        )
+        for weight_for_adapter in weight_scales
+    ]
+
+    return expanded_weight_scales
+
+
+def _maybe_expand_lora_scales_for_one_adapter(
+    scales: Union[float, Dict],
+    blocks_with_transformer: Dict[str, int],
+    transformer_per_block: Dict[str, int],
+    state_dict: None,
+    default_scale: float = 1.0,
+):
+    """
+    Expands the inputs into a more granular dictionary. See the example below for more details.
+
+    Parameters:
+        scales (`Union[float, Dict]`):
+            Scales dict to expand.
+        blocks_with_transformer (`Dict[str, int]`):
+            Dict with keys 'up' and 'down', showing which blocks have transformer layers
+        transformer_per_block (`Dict[str, int]`):
+            Dict with keys 'up' and 'down', showing how many transformer layers each block has
+
+    E.g. turns
+    ```python
+    scales = {"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}}
+    blocks_with_transformer = {"down": [1, 2], "up": [0, 1]}
+    transformer_per_block = {"down": 2, "up": 3}
+    ```
+    into
+    ```python
+    {
+        "down.block_1.0": 2,
+        "down.block_1.1": 2,
+        "down.block_2.0": 2,
+        "down.block_2.1": 2,
+        "mid": 3,
+        "up.block_0.0": 4,
+        "up.block_0.1": 4,
+        "up.block_0.2": 4,
+        "up.block_1.0": 5,
+        "up.block_1.1": 6,
+        "up.block_1.2": 7,
+    }
+    ```
+    """
+    if sorted(blocks_with_transformer.keys()) != ["down", "up"]:
+        raise ValueError("blocks_with_transformer needs to be a dict with keys `'down' and `'up'`")
+
+    if sorted(transformer_per_block.keys()) != ["down", "up"]:
+        raise ValueError("transformer_per_block needs to be a dict with keys `'down' and `'up'`")
+
+    if not isinstance(scales, dict):
+        # don't expand if scales is a single number
+        return scales
+
+    scales = copy.deepcopy(scales)
+
+    if "mid" not in scales:
+        scales["mid"] = default_scale
+    elif isinstance(scales["mid"], list):
+        if len(scales["mid"]) == 1:
+            scales["mid"] = scales["mid"][0]
+        else:
+            raise ValueError(f"Expected 1 scales for mid, got {len(scales['mid'])}.")
+
+    for updown in ["up", "down"]:
+        if updown not in scales:
+            scales[updown] = default_scale
+
+        # eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}}
+        if not isinstance(scales[updown], dict):
+            scales[updown] = {f"block_{i}": copy.deepcopy(scales[updown]) for i in blocks_with_transformer[updown]}
+
+        # eg {"down": {"block_1": 1}} to {"down": {"block_1": [1, 1]}}
+        for i in blocks_with_transformer[updown]:
+            block = f"block_{i}"
+            # set not assigned blocks to default scale
+            if block not in scales[updown]:
+                scales[updown][block] = default_scale
+            if not isinstance(scales[updown][block], list):
+                scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])]
+            elif len(scales[updown][block]) == 1:
+                # a list specifying scale to each masked IP input
+                scales[updown][block] = scales[updown][block] * transformer_per_block[updown]
+            elif len(scales[updown][block]) != transformer_per_block[updown]:
+                raise ValueError(
+                    f"Expected {transformer_per_block[updown]} scales for {updown}.{block}, got {len(scales[updown][block])}."
+                )
+
+        # eg {"down": "block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1}
+        for i in blocks_with_transformer[updown]:
+            block = f"block_{i}"
+            for tf_idx, value in enumerate(scales[updown][block]):
+                scales[f"{updown}.{block}.{tf_idx}"] = value
+
+        del scales[updown]
+
+    for layer in scales.keys():
+        if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
+            raise ValueError(
+                f"Can't set lora scale for layer {layer}. It either doesn't exist in this unet or it has no attentions."
+            )
+
+    return {_translate_into_actual_layer_name(name): weight for name, weight in scales.items()}
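A quick, illustrative-only dry run of the expansion helper added above (it is a private helper, so this is not a public API example; assumes diffusers 0.28.0 is installed). Only the parameter names of the state dict matter for the final validation step, so a dummy dict is enough:

```python
from diffusers.loaders.unet_loader_utils import _maybe_expand_lora_scales_for_one_adapter

scales = {"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}}
blocks_with_transformer = {"down": [1, 2], "up": [0, 1]}
transformer_per_block = {"down": 2, "up": 3}

# Dummy parameter names containing each resolved layer name, standing in for unet.state_dict().
dummy_state_dict = {"mid_block.attentions.0.w": None}
for b in (1, 2):
    for a in range(2):
        dummy_state_dict[f"down_blocks.{b}.attentions.{a}.w"] = None
for b in (0, 1):
    for a in range(3):
        dummy_state_dict[f"up_blocks.{b}.attentions.{a}.w"] = None

expanded = _maybe_expand_lora_scales_for_one_adapter(
    scales, blocks_with_transformer, transformer_per_block, dummy_state_dict
)
# Keys come back translated to real layer names, e.g. "down_blocks.1.attentions.0": 2,
# "mid_block.attentions.0": 3, "up_blocks.1.attentions.2": 7.
print(expanded)
```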