diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +19 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +20 -26
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +42 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
- diffusers/schedulers/scheduling_edm_euler.py +50 -31
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
- diffusers/schedulers/scheduling_euler_discrete.py +160 -68
- diffusers/schedulers/scheduling_heun_discrete.py +57 -39
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +24 -26
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
- diffusers-0.28.0.dist-info/RECORD +414 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
- diffusers-0.27.2.dist-info/RECORD +0 -399
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/models/attention_processor.py

@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+import math
 from importlib import import_module
-from typing import Callable, Optional, Union
+from typing import Callable, List, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -21,13 +22,15 @@ from torch import nn
 
 from ..image_processor import IPAdapterMaskProcessor
 from ..utils import deprecate, logging
-from ..utils.import_utils import is_xformers_available
+from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import maybe_allow_in_graph
 from .lora import LoRALinearLayer
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+if is_torch_npu_available():
+    import torch_npu
 
 if is_xformers_available():
     import xformers
@@ -181,25 +184,22 @@ class Attention(nn.Module):
                 f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
             )
 
-
-
-        self.linear_cls = linear_cls
-        self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
+        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
 
         if not self.only_cross_attention:
            # only relevant for the `AddedKVProcessor` classes
-            self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
-            self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
+            self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
+            self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
         else:
             self.to_k = None
             self.to_v = None
 
         if self.added_kv_proj_dim is not None:
-            self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
-            self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
+            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
+            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
 
         self.to_out = nn.ModuleList([])
-        self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
         self.to_out.append(nn.Dropout(dropout))
 
         # set attention processor
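The `linear_cls` indirection that 0.27.x used to swap in `LoRACompatibleLinear` is gone; every projection is now a plain `nn.Linear`, with LoRA handled through the PEFT backend instead. A minimal sketch of what this means for code that introspects the layers (the constructor arguments here are illustrative, not from the diff):

    import torch
    from diffusers.models.attention_processor import Attention

    attn = Attention(query_dim=320, heads=8, dim_head=40)
    # As of 0.28.0 the projections are ordinary torch.nn.Linear modules.
    assert isinstance(attn.to_q, torch.nn.Linear)
    assert isinstance(attn.to_out[0], torch.nn.Linear)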
@@ -212,6 +212,23 @@ class Attention(nn.Module):
             )
         self.set_processor(processor)
 
+    def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
+        r"""
+        Set whether to use npu flash attention from `torch_npu` or not.
+
+        """
+        if use_npu_flash_attention:
+            processor = AttnProcessorNPU()
+        else:
+            # set attention processor
+            # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+            # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+            # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )
+        self.set_processor(processor)
+
     def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ) -> None:
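A hedged usage sketch for the new toggle (the checkpoint id is illustrative; `AttnProcessorNPU.__init__` raises ImportError when torch_npu is unavailable, so this is only meaningful on Ascend NPU hosts):

    from diffusers import UNet2DConditionModel
    from diffusers.models.attention_processor import Attention

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"
    )
    for module in unet.modules():
        if isinstance(module, Attention):
            module.set_use_npu_flash_attention(True)  # installs AttnProcessorNPU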
@@ -424,7 +441,7 @@ class Attention(nn.Module):
         # If doesn't apply LoRA do `add_k_proj` or `add_v_proj`
         is_lora_activated.pop("add_k_proj", None)
         is_lora_activated.pop("add_v_proj", None)
-        # 2. else it is not
+        # 2. else it is not possible that only some layers have LoRA activated
         if not all(is_lora_activated.values()):
             raise ValueError(
                 f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
@@ -486,9 +503,9 @@ class Attention(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         **cross_attention_kwargs,
     ) -> torch.Tensor:
         r"""
@@ -706,7 +723,7 @@ class Attention(nn.Module):
             out_features = concatenated_weights.shape[0]
 
             # create a new single projection layer and copy over the weights.
-            self.to_qkv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
             self.to_qkv.weight.copy_(concatenated_weights)
             if self.use_bias:
                 concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
@@ -717,7 +734,7 @@ class Attention(nn.Module):
             in_features = concatenated_weights.shape[1]
             out_features = concatenated_weights.shape[0]
 
-            self.to_kv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
             self.to_kv.weight.copy_(concatenated_weights)
             if self.use_bias:
                 concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
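The two hunks above only swap the layer class; the fusion itself concatenates the q/k/v weights row-wise so a single GEMM produces all three projections. A self-contained sketch of that equivalence (standalone code, not diffusers API):

    import torch
    import torch.nn as nn

    dim = 64
    to_q, to_k, to_v = (nn.Linear(dim, dim, bias=False) for _ in range(3))
    to_qkv = nn.Linear(dim, 3 * dim, bias=False)
    # Row-wise concatenation of the three weight matrices, as fuse_projections does.
    to_qkv.weight.data.copy_(torch.cat([to_q.weight.data, to_k.weight.data, to_v.weight.data]))

    x = torch.randn(2, 10, dim)
    q, k, v = to_qkv(x).chunk(3, dim=-1)
    assert torch.allclose(q, to_q(x), atol=1e-6)  # one matmul now yields q, k and v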
@@ -734,10 +751,10 @@ class AttnProcessor:
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
         *args,
         **kwargs,
     ) -> torch.Tensor:
@@ -846,9 +863,9 @@ class CustomDiffusionAttnProcessor(nn.Module):
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         batch_size, sequence_length, _ = hidden_states.shape
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
@@ -911,9 +928,9 @@ class AttnAddedKVProcessor:
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         *args,
         **kwargs,
     ) -> torch.Tensor:
@@ -984,9 +1001,9 @@ class AttnAddedKVProcessor2_0:
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         *args,
         **kwargs,
     ) -> torch.Tensor:
@@ -1063,9 +1080,9 @@ class XFormersAttnAddedKVProcessor:
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
@@ -1134,13 +1151,13 @@ class XFormersAttnProcessor:
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1210,6 +1227,116 @@ class XFormersAttnProcessor:
         return hidden_states
 
 
+class AttnProcessorNPU:
+
+    r"""
+    Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
+    fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
+    not significant.
+
+    """
+
+    def __init__(self):
+        if not is_torch_npu_available():
+            raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        if query.dtype in (torch.float16, torch.bfloat16):
+            hidden_states = torch_npu.npu_fusion_attention(
+                query,
+                key,
+                value,
+                attn.heads,
+                input_layout="BNSD",
+                pse=None,
+                atten_mask=attention_mask,
+                scale=1.0 / math.sqrt(query.shape[-1]),
+                pre_tockens=65536,
+                next_tockens=65536,
+                keep_prob=1.0,
+                sync=False,
+                inner_precise=0,
+            )[0]
+        else:
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class AttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
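For orientation, a hedged sketch of putting the new processor to work (model id illustrative; requires an Ascend NPU with torch_npu installed):

    import torch
    from diffusers import StableDiffusionPipeline
    from diffusers.models.attention_processor import AttnProcessorNPU

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )
    # fp16/bf16 queries take the torch_npu.npu_fusion_attention path; fp32 falls
    # back to F.scaled_dot_product_attention, as the docstring above notes.
    pipe.unet.set_attn_processor(AttnProcessorNPU())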
@@ -1222,13 +1349,13 @@ class AttnProcessor2_0:
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1301,9 +1428,9 @@ class AttnProcessor2_0:
 
 class FusedAttnProcessor2_0:
     r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-
-
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
+    fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
+    For cross-attention modules, key and value projection matrices are fused.
 
     <Tip warning={true}>
 
@@ -1321,13 +1448,13 @@ class FusedAttnProcessor2_0:
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
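The expanded docstring matches how this processor is reached in practice: pipelines that support QKV fusion install it after fusing the weights. A hedged sketch (checkpoint id illustrative):

    import torch
    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    )
    # fuse_qkv_projections() calls Attention.fuse_projections() on each block and
    # swaps in FusedAttnProcessor2_0 so the fused to_qkv/to_kv layers are used.
    pipe.fuse_qkv_projections()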
@@ -1454,10 +1581,10 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module):
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         batch_size, sequence_length, _ = (
             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         )
@@ -1565,10 +1692,10 @@ class CustomDiffusionAttnProcessor2_0(nn.Module):
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         batch_size, sequence_length, _ = hidden_states.shape
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         if self.train_q_out:
@@ -1646,10 +1773,10 @@ class SlicedAttnProcessor:
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         residual = hidden_states
 
         input_ndim = hidden_states.ndim
@@ -1733,11 +1860,11 @@ class SlicedAttnAddedKVProcessor:
     def __call__(
         self,
         attn: "Attention",
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        temb: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         residual = hidden_states
 
         if attn.spatial_norm is not None:
@@ -1830,7 +1957,7 @@ class SpatialNorm(nn.Module):
         self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
         self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
 
-    def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
         f_size = f.shape[-2:]
         zq = F.interpolate(zq, size=f_size, mode="nearest")
         norm_f = self.norm_layer(f)
@@ -1876,7 +2003,7 @@ class LoRAAttnProcessor(nn.Module):
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
         self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
 
-    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+    def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
         self_cls_name = self.__class__.__name__
         deprecate(
             self_cls_name,
@@ -1937,7 +2064,7 @@ class LoRAAttnProcessor2_0(nn.Module):
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
         self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
 
-    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+    def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
         self_cls_name = self.__class__.__name__
         deprecate(
             self_cls_name,
@@ -2016,7 +2143,7 @@ class LoRAXFormersAttnProcessor(nn.Module):
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
         self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
 
-    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+    def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
         self_cls_name = self.__class__.__name__
         deprecate(
             self_cls_name,
@@ -2075,7 +2202,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
         self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
         self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
 
-    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+    def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
         self_cls_name = self.__class__.__name__
         deprecate(
             self_cls_name,
@@ -2098,7 +2225,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
 
 class IPAdapterAttnProcessor(nn.Module):
     r"""
-    Attention processor for Multiple IP-
+    Attention processor for Multiple IP-Adapters.
 
     Args:
         hidden_size (`int`):
@@ -2137,12 +2264,12 @@ class IPAdapterAttnProcessor(nn.Module):
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
         scale: float = 1.0,
-        ip_adapter_masks: Optional[torch.FloatTensor] = None,
+        ip_adapter_masks: Optional[torch.Tensor] = None,
     ):
         residual = hidden_states
 
@@ -2152,8 +2279,8 @@ class IPAdapterAttnProcessor(nn.Module):
             encoder_hidden_states, ip_hidden_states = encoder_hidden_states
         else:
             deprecation_message = (
-                "You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
-                " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
+                "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
+                " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
             )
             deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
             end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
@@ -2198,15 +2325,33 @@ class IPAdapterAttnProcessor(nn.Module):
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
         if ip_adapter_masks is not None:
-            if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
+            if not isinstance(ip_adapter_masks, List):
+                # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+                ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+            if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
                 raise ValueError(
-                    "ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
-                    " Please use `IPAdapterMaskProcessor` to preprocess your mask"
-                )
-            if len(ip_adapter_masks) != len(self.scale):
-                raise ValueError(
-                    f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
+                    f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+                    f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+                    f"({len(ip_hidden_states)})"
                 )
+            else:
+                for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+                    if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+                        raise ValueError(
+                            "Each element of the ip_adapter_masks array should be a tensor with shape "
+                            "[1, num_images_for_ip_adapter, height, width]."
+                            " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+                        )
+                    if mask.shape[1] != ip_state.shape[1]:
+                        raise ValueError(
+                            f"Number of masks ({mask.shape[1]}) does not match "
+                            f"number of ip images ({ip_state.shape[1]}) at index {index}"
+                        )
+                    if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+                        raise ValueError(
+                            f"Number of masks ({mask.shape[1]}) does not match "
+                            f"number of scales ({len(scale)}) at index {index}"
+                        )
         else:
             ip_adapter_masks = [None] * len(self.scale)
 
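The validation now expects one [1, num_images_for_ip_adapter, height, width] tensor per adapter, which is what IPAdapterMaskProcessor produces. A hedged sketch of the expected preprocessing (the solid placeholder masks stand in for real segmentation masks):

    from PIL import Image
    from diffusers.image_processor import IPAdapterMaskProcessor

    face_mask = Image.new("L", (1024, 1024), 255)       # placeholder mask images
    background_mask = Image.new("L", (1024, 1024), 0)
    processor = IPAdapterMaskProcessor()
    masks = processor.preprocess([face_mask, background_mask], height=1024, width=1024)
    # masks should come out as a [2, 1, 1024, 1024] tensor; hand it to the
    # pipeline via cross_attention_kwargs={"ip_adapter_masks": masks}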
@@ -2214,26 +2359,51 @@ class IPAdapterAttnProcessor(nn.Module):
         for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
             ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
         ):
-            [20 deleted lines truncated in the original rendering]
+            skip = False
+            if isinstance(scale, list):
+                if all(s == 0 for s in scale):
+                    skip = True
+            elif scale == 0:
+                skip = True
+            if not skip:
+                if mask is not None:
+                    if not isinstance(scale, list):
+                        scale = [scale] * mask.shape[1]
+
+                    current_num_images = mask.shape[1]
+                    for i in range(current_num_images):
+                        ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+                        ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+                        ip_key = attn.head_to_batch_dim(ip_key)
+                        ip_value = attn.head_to_batch_dim(ip_value)
+
+                        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+                        _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+                        _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
+
+                        mask_downsample = IPAdapterMaskProcessor.downsample(
+                            mask[:, i, :, :],
+                            batch_size,
+                            _current_ip_hidden_states.shape[1],
+                            _current_ip_hidden_states.shape[2],
+                        )
+
+                        mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+
+                        hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+                else:
+                    ip_key = to_k_ip(current_ip_hidden_states)
+                    ip_value = to_v_ip(current_ip_hidden_states)
+
+                    ip_key = attn.head_to_batch_dim(ip_key)
+                    ip_value = attn.head_to_batch_dim(ip_value)
+
+                    ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+                    current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+                    current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
+
+                    hidden_states = hidden_states + scale * current_ip_hidden_states
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
@@ -2253,7 +2423,7 @@ class IPAdapterAttnProcessor(nn.Module):
 
 class IPAdapterAttnProcessor2_0(torch.nn.Module):
     r"""
-    Attention processor for IP-
+    Attention processor for IP-Adapter for PyTorch 2.0.
 
     Args:
         hidden_size (`int`):
@@ -2297,12 +2467,12 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
     def __call__(
         self,
        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
         scale: float = 1.0,
-        ip_adapter_masks: Optional[torch.FloatTensor] = None,
+        ip_adapter_masks: Optional[torch.Tensor] = None,
     ):
         residual = hidden_states
 
@@ -2312,8 +2482,8 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
             encoder_hidden_states, ip_hidden_states = encoder_hidden_states
         else:
             deprecation_message = (
-                "You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
-                " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
+                "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
+                " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
            )
             deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
             end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
@@ -2372,15 +2542,33 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
         hidden_states = hidden_states.to(query.dtype)
 
         if ip_adapter_masks is not None:
-            if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
-                raise ValueError(
-                    "ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
-                    " Please use `IPAdapterMaskProcessor` to preprocess your mask"
-                )
-            if len(ip_adapter_masks) != len(self.scale):
+            if not isinstance(ip_adapter_masks, List):
+                # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+                ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+            if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
                 raise ValueError(
-                    f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
+                    f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+                    f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+                    f"({len(ip_hidden_states)})"
                 )
+            else:
+                for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+                    if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+                        raise ValueError(
+                            "Each element of the ip_adapter_masks array should be a tensor with shape "
+                            "[1, num_images_for_ip_adapter, height, width]."
+                            " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+                        )
+                    if mask.shape[1] != ip_state.shape[1]:
+                        raise ValueError(
+                            f"Number of masks ({mask.shape[1]}) does not match "
+                            f"number of ip images ({ip_state.shape[1]}) at index {index}"
+                        )
+                    if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+                        raise ValueError(
+                            f"Number of masks ({mask.shape[1]}) does not match "
+                            f"number of scales ({len(scale)}) at index {index}"
+                        )
         else:
             ip_adapter_masks = [None] * len(self.scale)
 
@@ -2388,33 +2576,64 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
         for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
             ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
         ):
-            [23 deleted lines truncated in the original rendering]
+            skip = False
+            if isinstance(scale, list):
+                if all(s == 0 for s in scale):
+                    skip = True
+            elif scale == 0:
+                skip = True
+            if not skip:
+                if mask is not None:
+                    if not isinstance(scale, list):
+                        scale = [scale] * mask.shape[1]
+
+                    current_num_images = mask.shape[1]
+                    for i in range(current_num_images):
+                        ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+                        ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+                        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                        # TODO: add support for attn.scale when we move to Torch 2.1
+                        _current_ip_hidden_states = F.scaled_dot_product_attention(
+                            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                        )
+
+                        _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
+                            batch_size, -1, attn.heads * head_dim
+                        )
+                        _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
+
+                        mask_downsample = IPAdapterMaskProcessor.downsample(
+                            mask[:, i, :, :],
+                            batch_size,
+                            _current_ip_hidden_states.shape[1],
+                            _current_ip_hidden_states.shape[2],
+                        )
+
+                        mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+                        hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+                else:
+                    ip_key = to_k_ip(current_ip_hidden_states)
+                    ip_value = to_v_ip(current_ip_hidden_states)
+
+                    ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                    ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                    # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                    # TODO: add support for attn.scale when we move to Torch 2.1
+                    current_ip_hidden_states = F.scaled_dot_product_attention(
+                        query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                    )
 
-
+                    current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+                        batch_size, -1, attn.heads * head_dim
+                    )
+                    current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
 
-
+                    hidden_states = hidden_states + scale * current_ip_hidden_states
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
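Note the new skip branch at the top of both loops: a zero scale now bypasses the IP-Adapter projections and attention entirely rather than computing the branch and discarding it. A hedged sketch of toggling that from user code (checkpoint and adapter ids illustrative):

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )
    pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
    pipe.set_ip_adapter_scale(0.0)  # adapter stays loaded but costs nothing at attention time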