diffusers 0.26.2__py3-none-any.whl → 0.27.0__py3-none-any.whl
- diffusers/__init__.py +20 -1
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/diffusers_cli.py +1 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +7 -3
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +2 -2
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/image_processor.py +110 -4
- diffusers/loaders/autoencoder.py +28 -8
- diffusers/loaders/controlnet.py +17 -8
- diffusers/loaders/ip_adapter.py +86 -23
- diffusers/loaders/lora.py +105 -310
- diffusers/loaders/lora_conversion_utils.py +1 -1
- diffusers/loaders/peft.py +1 -1
- diffusers/loaders/single_file.py +51 -12
- diffusers/loaders/single_file_utils.py +278 -49
- diffusers/loaders/textual_inversion.py +23 -4
- diffusers/loaders/unet.py +195 -41
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +3 -1
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +26 -36
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +171 -114
- diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +1 -1
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flax.py +1 -1
- diffusers/models/downsampling.py +8 -12
- diffusers/models/dual_transformer_2d.py +1 -1
- diffusers/models/embeddings.py +3 -4
- diffusers/models/embeddings_flax.py +1 -1
- diffusers/models/lora.py +33 -10
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +4 -6
- diffusers/models/normalization.py +1 -1
- diffusers/models/resnet.py +31 -58
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/t5_film_transformer.py +1 -1
- diffusers/models/transformer_2d.py +1 -1
- diffusers/models/transformer_temporal.py +1 -1
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/t5_film_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +29 -31
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unet_1d.py +1 -1
- diffusers/models/unet_1d_blocks.py +1 -1
- diffusers/models/unet_2d.py +1 -1
- diffusers/models/unet_2d_blocks.py +1 -1
- diffusers/models/unet_2d_condition.py +1 -1
- diffusers/models/unets/__init__.py +1 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +4 -4
- diffusers/models/unets/unet_2d_blocks.py +238 -98
- diffusers/models/unets/unet_2d_blocks_flax.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +420 -323
- diffusers/models/unets/unet_2d_condition_flax.py +21 -12
- diffusers/models/unets/unet_3d_blocks.py +50 -40
- diffusers/models/unets/unet_3d_condition.py +47 -8
- diffusers/models/unets/unet_i2vgen_xl.py +75 -30
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +48 -8
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +610 -0
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +10 -16
- diffusers/models/vae_flax.py +1 -1
- diffusers/models/vq_model.py +1 -1
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +26 -0
- diffusers/pipelines/amused/pipeline_amused.py +1 -1
- diffusers/pipelines/amused/pipeline_amused_img2img.py +1 -1
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff.py +162 -417
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +165 -137
- diffusers/pipelines/animatediff/pipeline_output.py +7 -6
- diffusers/pipelines/audioldm/pipeline_audioldm.py +3 -19
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +3 -3
- diffusers/pipelines/auto_pipeline.py +7 -16
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +90 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +98 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +92 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +145 -70
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +126 -89
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +108 -96
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -1
- diffusers/pipelines/ddim/pipeline_ddim.py +1 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +5 -5
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +5 -5
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +5 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +10 -120
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -91
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +1 -1
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +1 -1
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +5 -4
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +5 -4
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +7 -22
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -39
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +5 -5
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -22
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -2
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +1 -1
- diffusers/pipelines/free_init_utils.py +184 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +22 -104
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +2 -2
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +104 -93
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +112 -74
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ledits_pp/__init__.py +55 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +1505 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +1797 -0
- diffusers/pipelines/ledits_pp/pipeline_output.py +43 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +3 -19
- diffusers/pipelines/onnx_utils.py +1 -1
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +3 -3
- diffusers/pipelines/pia/pipeline_pia.py +168 -327
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +508 -0
- diffusers/pipelines/pipeline_utils.py +188 -534
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +56 -10
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +3 -3
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/__init__.py +50 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +482 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +311 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +638 -0
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +4 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +90 -146
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +4 -32
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -119
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -119
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +13 -59
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -31
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -33
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -21
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -21
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +5 -21
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +9 -38
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -34
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +6 -35
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +7 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +4 -124
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +282 -80
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +94 -46
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +3 -3
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +6 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +96 -148
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -154
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +98 -153
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +25 -87
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +89 -80
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +5 -49
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +80 -88
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +15 -86
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +20 -93
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +5 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +3 -19
- diffusers/pipelines/unclip/pipeline_unclip.py +1 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -1
- diffusers/pipelines/unclip/text_proj.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +35 -35
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +4 -21
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -5
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +2 -2
- diffusers/schedulers/__init__.py +7 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +1 -1
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +42 -19
- diffusers/schedulers/scheduling_ddim.py +2 -4
- diffusers/schedulers/scheduling_ddim_flax.py +13 -5
- diffusers/schedulers/scheduling_ddim_inverse.py +2 -4
- diffusers/schedulers/scheduling_ddim_parallel.py +2 -4
- diffusers/schedulers/scheduling_ddpm.py +2 -4
- diffusers/schedulers/scheduling_ddpm_flax.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -4
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +46 -19
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -21
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +9 -7
- diffusers/schedulers/scheduling_dpmsolver_sde.py +35 -35
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +52 -21
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +683 -0
- diffusers/schedulers/scheduling_edm_euler.py +381 -0
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +43 -15
- diffusers/schedulers/scheduling_euler_discrete.py +42 -17
- diffusers/schedulers/scheduling_euler_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_heun_discrete.py +35 -35
- diffusers/schedulers/scheduling_ipndm.py +37 -11
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +44 -44
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +44 -44
- diffusers/schedulers/scheduling_karras_ve_flax.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +38 -14
- diffusers/schedulers/scheduling_lms_discrete.py +43 -15
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +2 -4
- diffusers/schedulers/scheduling_pndm_flax.py +2 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +41 -9
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +1 -1
- diffusers/schedulers/scheduling_tcd.py +686 -0
- diffusers/schedulers/scheduling_unclip.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +46 -19
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +9 -2
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +1 -1
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +75 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +3 -3
- diffusers/utils/hub_utils.py +60 -16
- diffusers/utils/import_utils.py +15 -1
- diffusers/utils/loading_utils.py +2 -0
- diffusers/utils/logging.py +1 -1
- diffusers/utils/model_card_template.md +24 -0
- diffusers/utils/outputs.py +14 -7
- diffusers/utils/peft_utils.py +1 -1
- diffusers/utils/state_dict_utils.py +1 -1
- diffusers/utils/testing_utils.py +2 -0
- diffusers/utils/torch_utils.py +1 -1
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/METADATA +5 -5
- diffusers-0.27.0.dist-info/RECORD +399 -0
- diffusers-0.26.2.dist-info/RECORD +0 -384
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/LICENSE +0 -0
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/WHEEL +0 -0
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/top_level.txt +0 -0
diffusers/models/attention_processor.py

@@ -1,4 +1,4 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 from importlib import import_module
 from typing import Callable, Optional, Union
@@ -18,10 +19,11 @@ import torch
 import torch.nn.functional as F
 from torch import nn

-from ..utils import USE_PEFT_BACKEND, deprecate, logging
+from ..image_processor import IPAdapterMaskProcessor
+from ..utils import deprecate, logging
 from ..utils.import_utils import is_xformers_available
 from ..utils.torch_utils import maybe_allow_in_graph
-from .lora import LoRACompatibleLinear, LoRALinearLayer
+from .lora import LoRALinearLayer


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -114,6 +116,8 @@ class Attention(nn.Module):
         super().__init__()
         self.inner_dim = out_dim if out_dim is not None else dim_head * heads
         self.query_dim = query_dim
+        self.use_bias = bias
+        self.is_cross_attention = cross_attention_dim is not None
         self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax
@@ -177,10 +181,7 @@ class Attention(nn.Module):
                 f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
             )

-        if USE_PEFT_BACKEND:
-            linear_cls = nn.Linear
-        else:
-            linear_cls = LoRACompatibleLinear
+        linear_cls = nn.Linear

         self.linear_cls = linear_cls
         self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
@@ -509,6 +510,15 @@ class Attention(nn.Module):
         # The `Attention` class can call different attention processors / attention functions
         # here we simply pass along all tensors to the selected processor class
         # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+
+        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+        unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+            )
+        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
+
         return self.processor(
             self,
             hidden_states,
@@ -548,12 +558,16 @@ class Attention(nn.Module):
             `torch.Tensor`: The reshaped tensor.
         """
         head_size = self.heads
-        batch_size, seq_len, dim = tensor.shape
-        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        if tensor.ndim == 3:
+            batch_size, seq_len, dim = tensor.shape
+            extra_dim = 1
+        else:
+            batch_size, extra_dim, seq_len, dim = tensor.shape
+        tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
         tensor = tensor.permute(0, 2, 1, 3)

         if out_dim == 3:
-            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+            tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)

         return tensor
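
`head_to_batch_dim` now accepts 4-D inputs with an extra leading group dimension in addition to the usual `[batch, seq_len, dim]` layout; the extra dimension is folded into the sequence axis. A runnable replica of the new logic, to show the resulting shapes:

    import torch

    def head_to_batch_dim(tensor, head_size, out_dim=3):
        # Mirrors the updated method above: 4-D inputs fold extra_dim into seq_len.
        if tensor.ndim == 3:
            batch_size, seq_len, dim = tensor.shape
            extra_dim = 1
        else:
            batch_size, extra_dim, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)
        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
        return tensor

    print(head_to_batch_dim(torch.randn(2, 10, 64), head_size=4).shape)     # [8, 10, 16]
    print(head_to_batch_dim(torch.randn(2, 3, 10, 64), head_size=4).shape)  # [8, 30, 16]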
@@ -682,27 +696,32 @@ class Attention(nn.Module):

     @torch.no_grad()
     def fuse_projections(self, fuse=True):
-        is_cross_attention = self.cross_attention_dim != self.query_dim
         device = self.to_q.weight.data.device
         dtype = self.to_q.weight.data.dtype

-        if not is_cross_attention:
+        if not self.is_cross_attention:
             # fetch weight matrices.
             concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
             in_features = concatenated_weights.shape[1]
             out_features = concatenated_weights.shape[0]

             # create a new single projection layer and copy over the weights.
-            self.to_qkv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
+            self.to_qkv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
             self.to_qkv.weight.copy_(concatenated_weights)
+            if self.use_bias:
+                concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+                self.to_qkv.bias.copy_(concatenated_bias)

         else:
             concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
             in_features = concatenated_weights.shape[1]
             out_features = concatenated_weights.shape[0]

-            self.to_kv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
+            self.to_kv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
             self.to_kv.weight.copy_(concatenated_weights)
+            if self.use_bias:
+                concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+                self.to_kv.bias.copy_(concatenated_bias)

         self.fused_projections = fuse
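
`fuse_projections` now threads the layer's bias through fusion: the fused `to_qkv`/`to_kv` layer is created with `bias=self.use_bias`, and the per-projection bias vectors are concatenated alongside the weights. A standalone sketch showing why concatenation along the output dimension is equivalent:

    import torch
    from torch import nn

    # Three separate projections with bias, as in an unfused attention block.
    q, k, v = (nn.Linear(8, 8, bias=True) for _ in range(3))

    # One fused projection: stack weights and biases along the output dimension.
    qkv = nn.Linear(8, 24, bias=True)
    with torch.no_grad():
        qkv.weight.copy_(torch.cat([q.weight, k.weight, v.weight]))
        qkv.bias.copy_(torch.cat([q.bias, k.bias, v.bias]))

    x = torch.randn(2, 8)
    assert torch.allclose(qkv(x), torch.cat([q(x), k(x), v(x)], dim=-1), atol=1e-6)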
@@ -719,11 +738,14 @@ class AttnProcessor:
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         temb: Optional[torch.FloatTensor] = None,
-        scale: float = 1.0,
+        *args,
+        **kwargs,
     ) -> torch.Tensor:
-        residual = hidden_states
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)

-        args = () if USE_PEFT_BACKEND else (scale,)
+        residual = hidden_states

         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -742,15 +764,15 @@ class AttnProcessor:
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        query = attn.to_q(hidden_states, *args)
+        query = attn.to_q(hidden_states)

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

-        key = attn.to_k(encoder_hidden_states, *args)
-        value = attn.to_v(encoder_hidden_states, *args)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)

         query = attn.head_to_batch_dim(query)
         key = attn.head_to_batch_dim(key)
@@ -761,7 +783,7 @@ class AttnProcessor:
         hidden_states = attn.batch_to_head_dim(hidden_states)

         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, *args)
+        hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
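
With `scale` removed from the processor signatures, LoRA strength is supplied at the pipeline call site through `cross_attention_kwargs` and routed by the PEFT backend. A usage sketch (model and LoRA paths are placeholders):

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")
    pipe.load_lora_weights("path/to/lora")  # placeholder

    # `scale` now travels via cross_attention_kwargs, not a processor argument.
    image = pipe(
        "a photo of an astronaut riding a horse",
        cross_attention_kwargs={"scale": 0.8},
    ).images[0]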
@@ -892,11 +914,14 @@ class AttnAddedKVProcessor:
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
-        scale: float = 1.0,
+        *args,
+        **kwargs,
     ) -> torch.Tensor:
-        residual = hidden_states
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)

-        args = () if USE_PEFT_BACKEND else (scale,)
+        residual = hidden_states

         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
         batch_size, sequence_length, _ = hidden_states.shape
@@ -910,17 +935,17 @@ class AttnAddedKVProcessor:

         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        query = attn.to_q(hidden_states, *args)
+        query = attn.to_q(hidden_states)
         query = attn.head_to_batch_dim(query)

-        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)
-        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
         encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

         if not attn.only_cross_attention:
-            key = attn.to_k(hidden_states, *args)
-            value = attn.to_v(hidden_states, *args)
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
             key = attn.head_to_batch_dim(key)
             value = attn.head_to_batch_dim(value)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
@@ -934,7 +959,7 @@ class AttnAddedKVProcessor:
         hidden_states = attn.batch_to_head_dim(hidden_states)

         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, *args)
+        hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
@@ -962,11 +987,14 @@ class AttnAddedKVProcessor2_0:
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
-        scale: float = 1.0,
+        *args,
+        **kwargs,
     ) -> torch.Tensor:
-        residual = hidden_states
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)

-        args = () if USE_PEFT_BACKEND else (scale,)
+        residual = hidden_states

         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
         batch_size, sequence_length, _ = hidden_states.shape
@@ -980,7 +1008,7 @@ class AttnAddedKVProcessor2_0:

         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        query = attn.to_q(hidden_states, *args)
+        query = attn.to_q(hidden_states)
         query = attn.head_to_batch_dim(query, out_dim=4)

         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
@@ -989,8 +1017,8 @@ class AttnAddedKVProcessor2_0:
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)

         if not attn.only_cross_attention:
-            key = attn.to_k(hidden_states, *args)
-            value = attn.to_v(hidden_states, *args)
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
             key = attn.head_to_batch_dim(key, out_dim=4)
             value = attn.head_to_batch_dim(value, out_dim=4)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
@@ -1007,7 +1035,7 @@ class AttnAddedKVProcessor2_0:
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])

         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, *args)
+        hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
@@ -1110,11 +1138,14 @@ class XFormersAttnProcessor:
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         temb: Optional[torch.FloatTensor] = None,
-        scale: float = 1.0,
+        *args,
+        **kwargs,
     ) -> torch.FloatTensor:
-        residual = hidden_states
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)

-        args = () if USE_PEFT_BACKEND else (scale,)
+        residual = hidden_states

         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1143,15 +1174,15 @@ class XFormersAttnProcessor:
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        query = attn.to_q(hidden_states, *args)
+        query = attn.to_q(hidden_states)

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

-        key = attn.to_k(encoder_hidden_states, *args)
-        value = attn.to_v(encoder_hidden_states, *args)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)

         query = attn.head_to_batch_dim(query).contiguous()
         key = attn.head_to_batch_dim(key).contiguous()
@@ -1164,7 +1195,7 @@ class XFormersAttnProcessor:
         hidden_states = attn.batch_to_head_dim(hidden_states)

         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, *args)
+        hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
@@ -1195,8 +1226,13 @@ class AttnProcessor2_0:
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         temb: Optional[torch.FloatTensor] = None,
-        scale: float = 1.0,
+        *args,
+        **kwargs,
     ) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
         residual = hidden_states
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1220,16 +1256,15 @@ class AttnProcessor2_0:
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        args = () if USE_PEFT_BACKEND else (scale,)
-        query = attn.to_q(hidden_states, *args)
+        query = attn.to_q(hidden_states)

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

-        key = attn.to_k(encoder_hidden_states, *args)
-        value = attn.to_v(encoder_hidden_states, *args)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)

         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -1249,7 +1284,7 @@ class AttnProcessor2_0:
         hidden_states = hidden_states.to(query.dtype)

         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, *args)
+        hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
@@ -1290,8 +1325,13 @@ class FusedAttnProcessor2_0:
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         temb: Optional[torch.FloatTensor] = None,
-        scale: float = 1.0,
+        *args,
+        **kwargs,
     ) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
         residual = hidden_states
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1315,17 +1355,16 @@ class FusedAttnProcessor2_0:
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        args = () if USE_PEFT_BACKEND else (scale,)
         if encoder_hidden_states is None:
-            qkv = attn.to_qkv(hidden_states, *args)
+            qkv = attn.to_qkv(hidden_states)
             split_size = qkv.shape[-1] // 3
             query, key, value = torch.split(qkv, split_size, dim=-1)
         else:
             if attn.norm_cross:
                 encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-            query = attn.to_q(hidden_states, *args)
+            query = attn.to_q(hidden_states)

-            kv = attn.to_kv(encoder_hidden_states, *args)
+            kv = attn.to_kv(encoder_hidden_states)
             split_size = kv.shape[-1] // 2
             key, value = torch.split(kv, split_size, dim=-1)
@@ -1346,7 +1385,7 @@ class FusedAttnProcessor2_0:
         hidden_states = hidden_states.to(query.dtype)

         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, *args)
+        hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
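
The bias-aware fusion above is what `FusedAttnProcessor2_0` consumes. Models and pipelines that expose QKV fusion can opt in and out at runtime; a sketch, assuming a `pipe` object that supports it:

    # Fuse per-layer q/k/v (and k/v for cross-attention) into single matmuls.
    pipe.fuse_qkv_projections()
    image = pipe("a cat wearing sunglasses").images[0]
    pipe.unfuse_qkv_projections()  # restore the separate to_q/to_k/to_v layers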
@@ -1799,24 +1838,7 @@ class SpatialNorm(nn.Module):
         return new_f


-## Deprecated
 class LoRAAttnProcessor(nn.Module):
-    r"""
-    Processor for implementing the LoRA attention mechanism.
-
-    Args:
-        hidden_size (`int`, *optional*):
-            The hidden size of the attention layer.
-        cross_attention_dim (`int`, *optional*):
-            The number of channels in the `encoder_hidden_states`.
-        rank (`int`, defaults to 4):
-            The dimension of the LoRA update matrices.
-        network_alpha (`int`, *optional*):
-            Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
-        kwargs (`dict`):
-            Additional keyword arguments to pass to the `LoRALinearLayer` layers.
-    """
-
     def __init__(
         self,
         hidden_size: int,
@@ -1825,6 +1847,9 @@ class LoRAAttnProcessor(nn.Module):
         network_alpha: Optional[int] = None,
         **kwargs,
     ):
+        deprecation_message = "Using LoRAAttnProcessor is deprecated. Please use the PEFT backend for all things LoRA. You can install PEFT by running `pip install peft`."
+        deprecate("LoRAAttnProcessor", "0.30.0", deprecation_message, standard_warn=False)
+
         super().__init__()

         self.hidden_size = hidden_size
@@ -1851,7 +1876,7 @@ class LoRAAttnProcessor(nn.Module):
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
         self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)

-    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         self_cls_name = self.__class__.__name__
         deprecate(
             self_cls_name,
@@ -1869,27 +1894,10 @@ class LoRAAttnProcessor(nn.Module):

         attn._modules.pop("processor")
         attn.processor = AttnProcessor()
-        return attn.processor(attn, hidden_states, *args, **kwargs)
+        return attn.processor(attn, hidden_states, **kwargs)


 class LoRAAttnProcessor2_0(nn.Module):
-    r"""
-    Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
-    attention.
-
-    Args:
-        hidden_size (`int`):
-            The hidden size of the attention layer.
-        cross_attention_dim (`int`, *optional*):
-            The number of channels in the `encoder_hidden_states`.
-        rank (`int`, defaults to 4):
-            The dimension of the LoRA update matrices.
-        network_alpha (`int`, *optional*):
-            Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
-        kwargs (`dict`):
-            Additional keyword arguments to pass to the `LoRALinearLayer` layers.
-    """
-
     def __init__(
         self,
         hidden_size: int,
@@ -1898,6 +1906,9 @@ class LoRAAttnProcessor2_0(nn.Module):
         network_alpha: Optional[int] = None,
         **kwargs,
     ):
+        deprecation_message = "Using LoRAAttnProcessor is deprecated. Please use the PEFT backend for all things LoRA. You can install PEFT by running `pip install peft`."
+        deprecate("LoRAAttnProcessor2_0", "0.30.0", deprecation_message, standard_warn=False)
+
         super().__init__()
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@@ -1926,7 +1937,7 @@ class LoRAAttnProcessor2_0(nn.Module):
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
         self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)

-    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         self_cls_name = self.__class__.__name__
         deprecate(
             self_cls_name,
@@ -1944,7 +1955,7 @@ class LoRAAttnProcessor2_0(nn.Module):

         attn._modules.pop("processor")
         attn.processor = AttnProcessor2_0()
-        return attn.processor(attn, hidden_states, *args, **kwargs)
+        return attn.processor(attn, hidden_states, **kwargs)


 class LoRAXFormersAttnProcessor(nn.Module):
@@ -2005,7 +2016,7 @@ class LoRAXFormersAttnProcessor(nn.Module):
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
         self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)

-    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         self_cls_name = self.__class__.__name__
         deprecate(
             self_cls_name,
@@ -2023,7 +2034,7 @@ class LoRAXFormersAttnProcessor(nn.Module):

         attn._modules.pop("processor")
         attn.processor = XFormersAttnProcessor()
-        return attn.processor(attn, hidden_states, *args, **kwargs)
+        return attn.processor(attn, hidden_states, **kwargs)


 class LoRAAttnAddedKVProcessor(nn.Module):
@@ -2064,7 +2075,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
         self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
         self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

-    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         self_cls_name = self.__class__.__name__
         deprecate(
             self_cls_name,
@@ -2082,7 +2093,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):

         attn._modules.pop("processor")
         attn.processor = AttnAddedKVProcessor()
-        return attn.processor(attn, hidden_states, *args, **kwargs)
+        return attn.processor(attn, hidden_states, **kwargs)


 class IPAdapterAttnProcessor(nn.Module):
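
Every `LoRA*AttnProcessor` class above now emits a deprecation warning and forwards to its plain counterpart; the supported path is the PEFT backend. A sketch of the replacement workflow (repo path and adapter name are placeholders):

    # Requires `pip install peft`; LoRA is handled by loaders, not attention processors.
    pipe.load_lora_weights("path/to/lora", adapter_name="style")
    pipe.set_adapters(["style"], adapter_weights=[0.8])
    image = pipe("a watercolor landscape").images[0]
    pipe.unload_lora_weights()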
@@ -2125,12 +2136,13 @@ class IPAdapterAttnProcessor(nn.Module):

     def __call__(
         self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-        scale=1.0,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+        ip_adapter_masks: Optional[torch.FloatTensor] = None,
     ):
         residual = hidden_states

@@ -2185,9 +2197,22 @@ class IPAdapterAttnProcessor(nn.Module):
         hidden_states = torch.bmm(attention_probs, value)
         hidden_states = attn.batch_to_head_dim(hidden_states)

+        if ip_adapter_masks is not None:
+            if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
+                raise ValueError(
+                    " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
+                    " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+                )
+            if len(ip_adapter_masks) != len(self.scale):
+                raise ValueError(
+                    f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
+                )
+        else:
+            ip_adapter_masks = [None] * len(self.scale)
+
         # for ip-adapter
-        for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
-            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
+        for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
         ):
             ip_key = to_k_ip(current_ip_hidden_states)
             ip_value = to_v_ip(current_ip_hidden_states)
@@ -2199,6 +2224,15 @@ class IPAdapterAttnProcessor(nn.Module):
             current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
             current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)

+            if mask is not None:
+                mask_downsample = IPAdapterMaskProcessor.downsample(
+                    mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
+                )
+
+                mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+
+                current_ip_hidden_states = current_ip_hidden_states * mask_downsample
+
             hidden_states = hidden_states + scale * current_ip_hidden_states

         # linear proj
@@ -2262,12 +2296,13 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):

     def __call__(
         self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-        scale=1.0,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+        ip_adapter_masks: Optional[torch.FloatTensor] = None,
     ):
         residual = hidden_states

@@ -2336,9 +2371,22 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

+        if ip_adapter_masks is not None:
+            if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
+                raise ValueError(
+                    " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
+                    " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+                )
+            if len(ip_adapter_masks) != len(self.scale):
+                raise ValueError(
+                    f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
+                )
+        else:
+            ip_adapter_masks = [None] * len(self.scale)
+
         # for ip-adapter
-        for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
-            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
+        for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
         ):
             ip_key = to_k_ip(current_ip_hidden_states)
             ip_value = to_v_ip(current_ip_hidden_states)
@@ -2357,6 +2405,15 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
             )
             current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)

+            if mask is not None:
+                mask_downsample = IPAdapterMaskProcessor.downsample(
+                    mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
+                )
+
+                mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+
+                current_ip_hidden_states = current_ip_hidden_states * mask_downsample
+
             hidden_states = hidden_states + scale * current_ip_hidden_states

         # linear proj
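
The mask plumbing above lets each IP-Adapter be confined to a region of the image. Masks are prepared with the new `IPAdapterMaskProcessor` and passed through `cross_attention_kwargs`; a usage sketch, assuming a pipeline with two IP-Adapters loaded (`mask1`, `mask2`, `face_image_1`, `face_image_2` are placeholder PIL images):

    from diffusers.image_processor import IPAdapterMaskProcessor

    processor = IPAdapterMaskProcessor()
    # White pixels mark the region the corresponding adapter controls.
    masks = processor.preprocess([mask1, mask2], height=1024, width=1024)

    image = pipe(
        prompt="two people standing side by side",
        ip_adapter_image=[face_image_1, face_image_2],
        cross_attention_kwargs={"ip_adapter_masks": masks},
    ).images[0]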
diffusers/models/autoencoders/autoencoder_kl.py

@@ -1,4 +1,4 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -80,6 +80,8 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
         norm_num_groups: int = 32,
         sample_size: int = 32,
         scaling_factor: float = 0.18215,
+        latents_mean: Optional[Tuple[float]] = None,
+        latents_std: Optional[Tuple[float]] = None,
         force_upcast: float = True,
     ):
         super().__init__()
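
The two new optional config entries record per-channel latent statistics. Pipelines that consume them shift and scale latents around `scaling_factor`; the sketch below shows the typical denormalization step before decoding (the formula is an assumption mirroring downstream pipeline usage, with dummy statistics):

    import torch

    scaling_factor = 0.13025
    latents_mean = torch.zeros(1, 4, 1, 1)  # dummy per-channel statistics
    latents_std = torch.ones(1, 4, 1, 1)

    latents = torch.randn(1, 4, 128, 128)
    # Undo normalization before calling vae.decode(latents):
    latents = latents * latents_std / scaling_factor + latents_mean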