diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +97 -4
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +13 -1
- diffusers/image_processor.py +282 -71
- diffusers/loaders/__init__.py +24 -3
- diffusers/loaders/ip_adapter.py +543 -16
- diffusers/loaders/lora_base.py +138 -125
- diffusers/loaders/lora_conversion_utils.py +647 -0
- diffusers/loaders/lora_pipeline.py +2216 -230
- diffusers/loaders/peft.py +380 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +597 -10
- diffusers/loaders/textual_inversion.py +5 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +56 -12
- diffusers/models/__init__.py +49 -12
- diffusers/models/activations.py +22 -9
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +98 -13
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2160 -346
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +73 -12
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +26 -376
- diffusers/models/controlnet_sparsectrl.py +46 -719
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +996 -92
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +264 -14
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +334 -51
- diffusers/models/normalization.py +157 -13
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +10 -2
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +189 -51
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +112 -18
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +9 -9
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +46 -68
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +14 -6
- diffusers/pipelines/__init__.py +69 -6
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
- diffusers/pipelines/auto_pipeline.py +88 -10
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
- diffusers/pipelines/flux/__init__.py +23 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +256 -48
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +13 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +250 -31
- diffusers/pipelines/pipeline_utils.py +158 -186
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +6 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
- diffusers/schedulers/scheduling_deis_multistep.py +102 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +102 -6
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
- diffusers/training_utils.py +63 -19
- diffusers/utils/__init__.py +7 -1
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +240 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +44 -40
- diffusers/utils/import_utils.py +98 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +6 -3
- diffusers/utils/testing_utils.py +115 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/models/attention_processor.py

@@ -20,8 +20,8 @@ import torch.nn.functional as F
 from torch import nn
 
 from ..image_processor import IPAdapterMaskProcessor
-from ..utils import deprecate, logging
-from ..utils.import_utils import is_torch_npu_available, is_xformers_available
+from ..utils import deprecate, is_torch_xla_available, logging
+from ..utils.import_utils import is_torch_npu_available, is_torch_xla_version, is_xformers_available
 from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
 
 
@@ -36,6 +36,15 @@ if is_xformers_available():
 else:
     xformers = None
 
+if is_torch_xla_available():
+    # flash attention pallas kernel is introduced in the torch_xla 2.3 release.
+    if is_torch_xla_version(">", "2.2"):
+        from torch_xla.experimental.custom_kernel import flash_attention
+        from torch_xla.runtime import is_spmd
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 
 @maybe_allow_in_graph
 class Attention(nn.Module):
@@ -120,13 +129,16 @@ class Attention(nn.Module):
         _from_deprecated_attn_block: bool = False,
         processor: Optional["AttnProcessor"] = None,
         out_dim: int = None,
+        out_context_dim: int = None,
         context_pre_only=None,
         pre_only=False,
+        elementwise_affine: bool = True,
+        is_causal: bool = False,
     ):
         super().__init__()
 
         # To prevent circular import.
-        from .normalization import FP32LayerNorm, RMSNorm
+        from .normalization import FP32LayerNorm, LpNorm, RMSNorm
 
         self.inner_dim = out_dim if out_dim is not None else dim_head * heads
         self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
@@ -141,8 +153,10 @@ class Attention(nn.Module):
         self.dropout = dropout
         self.fused_projections = False
         self.out_dim = out_dim if out_dim is not None else query_dim
+        self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
         self.context_pre_only = context_pre_only
         self.pre_only = pre_only
+        self.is_causal = is_causal
 
         # we make use of this private variable to know whether this class is loaded
         # with an deprecated state dict so that we can convert it on the fly
@@ -179,20 +193,27 @@ class Attention(nn.Module):
             self.norm_q = None
             self.norm_k = None
         elif qk_norm == "layer_norm":
-            self.norm_q = nn.LayerNorm(dim_head, eps=eps)
-            self.norm_k = nn.LayerNorm(dim_head, eps=eps)
+            self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
         elif qk_norm == "fp32_layer_norm":
             self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
             self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
         elif qk_norm == "layer_norm_across_heads":
-            # Lumina
+            # Lumina applies qk norm across all heads
             self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
             self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
         elif qk_norm == "rms_norm":
             self.norm_q = RMSNorm(dim_head, eps=eps)
             self.norm_k = RMSNorm(dim_head, eps=eps)
+        elif qk_norm == "rms_norm_across_heads":
+            # LTX applies qk norm across all heads
+            self.norm_q = RMSNorm(dim_head * heads, eps=eps)
+            self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps)
+        elif qk_norm == "l2":
+            self.norm_q = LpNorm(p=2, dim=-1, eps=eps)
+            self.norm_k = LpNorm(p=2, dim=-1, eps=eps)
         else:
-            raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None
+            raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'")
 
         if cross_attention_norm is None:
             self.norm_cross = None
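Illustrative sketch (not part of the package diff): the two new `qk_norm` options, "rms_norm_across_heads" and "l2", are selected through the same constructor argument as the existing modes shown in the hunk above. The dimensions below are made up, and `kv_heads` has to be set for the across-heads variant so that `dim_head * kv_heads` is defined. The across-heads norms are consumed by model-specific processors (LTX) that normalize before splitting heads, so the sketch only inspects the constructed modules rather than running a forward pass.

```python
from diffusers.models.attention_processor import Attention

# "rms_norm_across_heads" normalizes the concatenation of all heads (as LTX does),
# while "l2" plugs in the new LpNorm(p=2) query/key normalization.
attn = Attention(query_dim=128, heads=4, kv_heads=4, dim_head=32, qk_norm="rms_norm_across_heads")
print(attn.norm_q)  # RMSNorm over dim_head * heads = 128 features

attn_l2 = Attention(query_dim=128, heads=4, dim_head=32, qk_norm="l2")
print(attn_l2.norm_q)  # LpNorm(p=2) applied along the last dimension
```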
@@ -233,14 +254,22 @@ class Attention(nn.Module):
             self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
             if self.context_pre_only is not None:
                 self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        else:
+            self.add_q_proj = None
+            self.add_k_proj = None
+            self.add_v_proj = None
 
         if not self.pre_only:
             self.to_out = nn.ModuleList([])
             self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
             self.to_out.append(nn.Dropout(dropout))
+        else:
+            self.to_out = None
 
         if self.context_pre_only is not None and not self.context_pre_only:
-            self.to_add_out = nn.Linear(self.inner_dim, self.
+            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
+        else:
+            self.to_add_out = None
 
         if qk_norm is not None and added_kv_proj_dim is not None:
             if qk_norm == "fp32_layer_norm":
@@ -249,6 +278,10 @@ class Attention(nn.Module):
             elif qk_norm == "rms_norm":
                 self.norm_added_q = RMSNorm(dim_head, eps=eps)
                 self.norm_added_k = RMSNorm(dim_head, eps=eps)
+            else:
+                raise ValueError(
+                    f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
+                )
         else:
             self.norm_added_q = None
             self.norm_added_k = None
@@ -263,6 +296,33 @@ class Attention(nn.Module):
             )
         self.set_processor(processor)
 
+    def set_use_xla_flash_attention(
+        self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None
+    ) -> None:
+        r"""
+        Set whether to use xla flash attention from `torch_xla` or not.
+
+        Args:
+            use_xla_flash_attention (`bool`):
+                Whether to use pallas flash attention kernel from `torch_xla` or not.
+            partition_spec (`Tuple[]`, *optional*):
+                Specify the partition specification if using SPMD. Otherwise None.
+        """
+        if use_xla_flash_attention:
+            if not is_torch_xla_available:
+                raise "torch_xla is not available"
+            elif is_torch_xla_version("<", "2.3"):
+                raise "flash attention pallas kernel is supported from torch_xla version 2.3"
+            elif is_spmd() and is_torch_xla_version("<", "2.4"):
+                raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4"
+            else:
+                processor = XLAFlashAttnProcessor2_0(partition_spec)
+        else:
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )
+        self.set_processor(processor)
+
     def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
         r"""
         Set whether to use npu flash attention from `torch_npu` or not.
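Illustrative sketch (not part of the package diff): the new `set_use_xla_flash_attention` hook above routes an `Attention` module through the XLA pallas flash-attention processor. A minimal way to opt a loaded model into it, assuming a TPU host with `torch_xla` >= 2.3 installed; the model id is only illustrative.

```python
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import Attention

unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="unet")

for module in unet.modules():
    if isinstance(module, Attention):
        # partition_spec is only needed when running under SPMD sharding;
        # on a plain (non-SPMD) TPU setup it can stay None.
        module.set_use_xla_flash_attention(True, partition_spec=None)
```

Calling the hook with `False` restores the default `AttnProcessor2_0`/`AttnProcessor` choice, mirroring the fallback branch in the hunk.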
@@ -306,6 +366,17 @@ class Attention(nn.Module):
                 XFormersAttnAddedKVProcessor,
             ),
         )
+        is_ip_adapter = hasattr(self, "processor") and isinstance(
+            self.processor,
+            (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
+        )
+        is_joint_processor = hasattr(self, "processor") and isinstance(
+            self.processor,
+            (
+                JointAttnProcessor2_0,
+                XFormersJointAttnProcessor,
+            ),
+        )
 
         if use_memory_efficient_attention_xformers:
             if is_added_kv_processor and is_custom_diffusion:
@@ -356,6 +427,21 @@ class Attention(nn.Module):
                     "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
                 )
                 processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
+            elif is_ip_adapter:
+                processor = IPAdapterXFormersAttnProcessor(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    num_tokens=self.processor.num_tokens,
+                    scale=self.processor.scale,
+                    attention_op=attention_op,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_ip"):
+                    processor.to(
+                        device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
+                    )
+            elif is_joint_processor:
+                processor = XFormersJointAttnProcessor(attention_op=attention_op)
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
@@ -374,6 +460,18 @@ class Attention(nn.Module):
                 processor.load_state_dict(self.processor.state_dict())
                 if hasattr(self.processor, "to_k_custom_diffusion"):
                     processor.to(self.processor.to_k_custom_diffusion.weight.device)
+            elif is_ip_adapter:
+                processor = IPAdapterAttnProcessor2_0(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    num_tokens=self.processor.num_tokens,
+                    scale=self.processor.scale,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_ip"):
+                    processor.to(
+                        device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
+                    )
             else:
                 # set attention processor
                 # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
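Illustrative sketch (not part of the package diff): the two `is_ip_adapter` branches above mean that enabling xFormers, or switching back, now swaps in IP-Adapter-aware processors and copies their state dict instead of silently dropping the adapter projections. The workflow this targets looks roughly as follows; the pipeline and adapter ids follow the usual diffusers examples and are assumptions, not taken from this diff.

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

# In 0.32.0 this swaps the IP-Adapter processors for IPAdapterXFormersAttnProcessor
# (reloading their weights) rather than replacing them with plain XFormersAttnProcessor.
pipe.enable_xformers_memory_efficient_attention()
```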
@@ -477,7 +575,7 @@ class Attention(nn.Module):
         # For standard processors that are defined here, `**cross_attention_kwargs` is empty
 
         attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
-        quiet_attn_parameters = {"ip_adapter_masks"}
+        quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
         unused_kwargs = [
             k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
         ]
@@ -692,7 +790,11 @@ class Attention(nn.Module):
             self.to_kv.bias.copy_(concatenated_bias)
 
         # handle added projections for SD3 and others.
-        if
+        if (
+            getattr(self, "add_q_proj", None) is not None
+            and getattr(self, "add_k_proj", None) is not None
+            and getattr(self, "add_v_proj", None) is not None
+        ):
             concatenated_weights = torch.cat(
                 [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
             )
@@ -712,6 +814,269 @@ class Attention(nn.Module):
         self.fused_projections = fuse
 
 
+class SanaMultiscaleAttentionProjection(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        num_attention_heads: int,
+        kernel_size: int,
+    ) -> None:
+        super().__init__()
+
+        channels = 3 * in_channels
+        self.proj_in = nn.Conv2d(
+            channels,
+            channels,
+            kernel_size,
+            padding=kernel_size // 2,
+            groups=channels,
+            bias=False,
+        )
+        self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.proj_in(hidden_states)
+        hidden_states = self.proj_out(hidden_states)
+        return hidden_states
+
+
+class SanaMultiscaleLinearAttention(nn.Module):
+    r"""Lightweight multi-scale linear attention"""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_attention_heads: Optional[int] = None,
+        attention_head_dim: int = 8,
+        mult: float = 1.0,
+        norm_type: str = "batch_norm",
+        kernel_sizes: Tuple[int, ...] = (5,),
+        eps: float = 1e-15,
+        residual_connection: bool = False,
+    ):
+        super().__init__()
+
+        # To prevent circular import
+        from .normalization import get_normalization
+
+        self.eps = eps
+        self.attention_head_dim = attention_head_dim
+        self.norm_type = norm_type
+        self.residual_connection = residual_connection
+
+        num_attention_heads = (
+            int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads
+        )
+        inner_dim = num_attention_heads * attention_head_dim
+
+        self.to_q = nn.Linear(in_channels, inner_dim, bias=False)
+        self.to_k = nn.Linear(in_channels, inner_dim, bias=False)
+        self.to_v = nn.Linear(in_channels, inner_dim, bias=False)
+
+        self.to_qkv_multiscale = nn.ModuleList()
+        for kernel_size in kernel_sizes:
+            self.to_qkv_multiscale.append(
+                SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
+            )
+
+        self.nonlinearity = nn.ReLU()
+        self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
+        self.norm_out = get_normalization(norm_type, num_features=out_channels)
+
+        self.processor = SanaMultiscaleAttnProcessor2_0()
+
+    def apply_linear_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)  # Adds padding
+        scores = torch.matmul(value, key.transpose(-1, -2))
+        hidden_states = torch.matmul(scores, query)
+
+        hidden_states = hidden_states.to(dtype=torch.float32)
+        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
+        return hidden_states
+
+    def apply_quadratic_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+        scores = torch.matmul(key.transpose(-1, -2), query)
+        scores = scores.to(dtype=torch.float32)
+        scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
+        hidden_states = torch.matmul(value, scores)
+        return hidden_states
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return self.processor(self, hidden_states)
+
+
+class MochiAttention(nn.Module):
+    def __init__(
+        self,
+        query_dim: int,
+        added_kv_proj_dim: int,
+        processor: "MochiAttnProcessor2_0",
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        added_proj_bias: bool = True,
+        out_dim: Optional[int] = None,
+        out_context_dim: Optional[int] = None,
+        out_bias: bool = True,
+        context_pre_only: bool = False,
+        eps: float = 1e-5,
+    ):
+        super().__init__()
+        from .normalization import MochiRMSNorm
+
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.out_dim = out_dim if out_dim is not None else query_dim
+        self.out_context_dim = out_context_dim if out_context_dim else query_dim
+        self.context_pre_only = context_pre_only
+
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+
+        self.norm_q = MochiRMSNorm(dim_head, eps, True)
+        self.norm_k = MochiRMSNorm(dim_head, eps, True)
+        self.norm_added_q = MochiRMSNorm(dim_head, eps, True)
+        self.norm_added_k = MochiRMSNorm(dim_head, eps, True)
+
+        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+        self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        if self.context_pre_only is not None:
+            self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(nn.Dropout(dropout))
+
+        if not self.context_pre_only:
+            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
+
+        self.processor = processor
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        return self.processor(
+            self,
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            **kwargs,
+        )
+
+
+class MochiAttnProcessor2_0:
+    """Attention processor used in Mochi."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: "MochiAttention",
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        encoder_query = attn.add_q_proj(encoder_hidden_states)
+        encoder_key = attn.add_k_proj(encoder_hidden_states)
+        encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
+        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
+        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
+
+        if attn.norm_added_q is not None:
+            encoder_query = attn.norm_added_q(encoder_query)
+        if attn.norm_added_k is not None:
+            encoder_key = attn.norm_added_k(encoder_key)
+
+        if image_rotary_emb is not None:
+
+            def apply_rotary_emb(x, freqs_cos, freqs_sin):
+                x_even = x[..., 0::2].float()
+                x_odd = x[..., 1::2].float()
+
+                cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
+                sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
+
+                return torch.stack([cos, sin], dim=-1).flatten(-2)
+
+            query = apply_rotary_emb(query, *image_rotary_emb)
+            key = apply_rotary_emb(key, *image_rotary_emb)
+
+        query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
+        encoder_query, encoder_key, encoder_value = (
+            encoder_query.transpose(1, 2),
+            encoder_key.transpose(1, 2),
+            encoder_value.transpose(1, 2),
+        )
+
+        sequence_length = query.size(2)
+        encoder_sequence_length = encoder_query.size(2)
+        total_length = sequence_length + encoder_sequence_length
+
+        batch_size, heads, _, dim = query.shape
+        attn_outputs = []
+        for idx in range(batch_size):
+            mask = attention_mask[idx][None, :]
+            valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
+
+            valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :]
+            valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :]
+            valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :]
+
+            valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2)
+            valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2)
+            valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2)
+
+            attn_output = F.scaled_dot_product_attention(
+                valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False
+            )
+            valid_sequence_length = attn_output.size(2)
+            attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
+            attn_outputs.append(attn_output)
+
+        hidden_states = torch.cat(attn_outputs, dim=0)
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+
+        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
+            (sequence_length, encoder_sequence_length), dim=1
+        )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if hasattr(attn, "to_add_out"):
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        return hidden_states, encoder_hidden_states
+
+
 class AttnProcessor:
     r"""
     Default processor for performing attention-related computations.
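Illustrative sketch (not part of the package diff): the rotary helper inside `MochiAttnProcessor2_0` above treats even/odd channels as the real/imaginary parts of a complex rotation. The standalone snippet below copies that helper only to demonstrate the channel pairing; the tensor shapes are made up for the demonstration.

```python
import torch


def apply_rotary_emb(x, freqs_cos, freqs_sin):
    # Same interleaved formulation as in MochiAttnProcessor2_0: even channels act as the
    # "real" part and odd channels as the "imaginary" part of each rotated pair.
    x_even = x[..., 0::2].float()
    x_odd = x[..., 1::2].float()
    cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
    sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
    return torch.stack([cos, sin], dim=-1).flatten(-2)


x = torch.randn(1, 4, 2, 8)    # (batch, seq_len, heads, head_dim)
ones = torch.ones(1, 4, 2, 4)  # one cos/sin value per channel pair
zeros = torch.zeros_like(ones)

# A zero rotation angle (cos=1, sin=0) leaves the tensor unchanged ...
assert torch.allclose(apply_rotary_emb(x, ones, zeros), x)
# ... and a 90-degree rotation (cos=0, sin=1) maps each (even, odd) pair to (-odd, even).
rotated = apply_rotary_emb(x, zeros, ones)
assert torch.allclose(rotated[..., 0::2], -x[..., 1::2])
assert torch.allclose(rotated[..., 1::2], x[..., 0::2])
```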
@@ -1049,61 +1414,72 @@ class JointAttnProcessor2_0:
     ) -> torch.FloatTensor:
         residual = hidden_states
 
-        input_ndim = hidden_states.ndim
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-        context_input_ndim = encoder_hidden_states.ndim
-        if context_input_ndim == 4:
-            batch_size, channel, height, width = encoder_hidden_states.shape
-            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size = encoder_hidden_states.shape[0]
+        batch_size = hidden_states.shape[0]
 
         # `sample` projections.
         query = attn.to_q(hidden_states)
         key = attn.to_k(hidden_states)
         value = attn.to_v(hidden_states)
 
-        # `context` projections.
-        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-        # attention
-        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
-        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
-        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
-
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
+
         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
+            key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
+            value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
+
         hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
-        # Split the attention outputs.
-        hidden_states, encoder_hidden_states = (
-            hidden_states[:, : residual.shape[1]],
-            hidden_states[:, residual.shape[1] :],
-        )
+        if encoder_hidden_states is not None:
+            # Split the attention outputs.
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, : residual.shape[1]],
+                hidden_states[:, residual.shape[1] :],
+            )
+            if not attn.context_pre_only:
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
-        if not attn.context_pre_only:
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-        if context_input_ndim == 4:
-            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
 
-
+        if encoder_hidden_states is not None:
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
 
 
 class PAGJointAttnProcessor2_0:
@@ -1120,6 +1496,7 @@ class PAGJointAttnProcessor2_0:
         attn: Attention,
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
     ) -> torch.FloatTensor:
         residual = hidden_states
 
@@ -1505,34 +1882,213 @@ class FusedJointAttnProcessor2_0:
         return hidden_states, encoder_hidden_states
 
 
-class
-    """
+class XFormersJointAttnProcessor:
+    r"""
+    Processor for implementing memory efficient attention using xFormers.
 
-
-
-
-
-
+    Args:
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+    """
+
+    def __init__(self, attention_op: Optional[Callable] = None):
+        self.attention_op = attention_op
 
     def __call__(
         self,
         attn: Attention,
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
         *args,
         **kwargs,
     ) -> torch.FloatTensor:
-
+        residual = hidden_states
 
         # `sample` projections.
         query = attn.to_q(hidden_states)
         key = attn.to_k(hidden_states)
         value = attn.to_v(hidden_states)
 
-
-
-
-
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = attn.head_to_batch_dim(encoder_hidden_states_query_proj).contiguous()
+            encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj).contiguous()
+            encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj).contiguous()
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+            key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+            value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if encoder_hidden_states is not None:
+            # Split the attention outputs.
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, : residual.shape[1]],
+                hidden_states[:, residual.shape[1] :],
+            )
+            if not attn.context_pre_only:
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if encoder_hidden_states is not None:
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
+class AllegroAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the Allegro model. It applies a normalization layer and rotary embedding on the query and key vector.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "AllegroAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None and not attn.is_cross_attention:
+            from .embeddings import apply_rotary_emb_allegro
+
+            query = apply_rotary_emb_allegro(query, image_rotary_emb[0], image_rotary_emb[1])
+            key = apply_rotary_emb_allegro(key, image_rotary_emb[0], image_rotary_emb[1])
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class AuraFlowAttnProcessor2_0:
+    """Attention processor used typically in processing Aura Flow."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
+            raise ImportError(
+                "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        batch_size = hidden_states.shape[0]
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
             encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
 
         # Reshape.
@@ -1695,52 +2251,231 @@ class FusedAuraFlowAttnProcessor2_0:
         return hidden_states


-
-
-    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
-    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
-    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
-    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
-    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
-
-
-class FluxSingleAttnProcessor2_0:
-    r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    """
+class FluxAttnProcessor2_0:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""

     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("
+            raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.
-        encoder_hidden_states:
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.
-
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

-
-
-
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        if encoder_hidden_states is not None:
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
+class FluxAttnProcessor2_0_NPU:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
         batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

+        # `sample` projections.
         query = attn.to_q(hidden_states)
-
-
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)

         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads

         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        if encoder_hidden_states is not None:
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        if query.dtype in (torch.float16, torch.bfloat16):
+            hidden_states = torch_npu.npu_fusion_attention(
+                query,
+                key,
+                value,
+                attn.heads,
+                input_layout="BNSD",
+                pse=None,
+                scale=1.0 / math.sqrt(query.shape[-1]),
+                pre_tockens=65536,
+                next_tockens=65536,
+                keep_prob=1.0,
+                sync=False,
+                inner_precise=0,
+            )[0]
+        else:
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )

+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
+class FusedFluxAttnProcessor2_0:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

@@ -1749,33 +2484,212 @@ class FluxSingleAttnProcessor2_0:
         if attn.norm_k is not None:
             key = attn.norm_k(key)

-        #
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+            split_size = encoder_qkv.shape[-1] // 3
+            (
+                encoder_hidden_states_query_proj,
+                encoder_hidden_states_key_proj,
+                encoder_hidden_states_value_proj,
+            ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
         if image_rotary_emb is not None:
-
-
-
-
-            query, key = apply_rope(query, key, image_rotary_emb)
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)

-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
+class FusedFluxAttnProcessor2_0_NPU:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU"
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+            split_size = encoder_qkv.shape[-1] // 3
+            (
+                encoder_hidden_states_query_proj,
+                encoder_hidden_states_key_proj,
+                encoder_hidden_states_value_proj,
+            ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        if query.dtype in (torch.float16, torch.bfloat16):
+            hidden_states = torch_npu.npu_fusion_attention(
+                query,
+                key,
+                value,
+                attn.heads,
+                input_layout="BNSD",
+                pse=None,
+                scale=1.0 / math.sqrt(query.shape[-1]),
+                pre_tockens=65536,
+                next_tockens=65536,
+                keep_prob=1.0,
+                sync=False,
+                inner_precise=0,
+            )[0]
+        else:
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

-        if
-            hidden_states =
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )

-
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
+class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
+    """Flux Attention processor for IP-Adapter."""
+
+    def __init__(
+        self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
+    ):
+        super().__init__()
+
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim

+        if not isinstance(num_tokens, (tuple, list)):
+            num_tokens = [num_tokens]

-
-
+        if not isinstance(scale, list):
+            scale = [scale] * len(num_tokens)
+        if len(scale) != len(num_tokens):
+            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
+        self.scale = scale

-
-
-
+        self.to_k_ip = nn.ModuleList(
+            [
+                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
+                for _ in range(len(num_tokens))
+            ]
+        )
+        self.to_v_ip = nn.ModuleList(
+            [
+                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
+                for _ in range(len(num_tokens))
+            ]
+        )

     def __call__(
         self,
@@ -1784,88 +2698,102 @@ class FluxAttnProcessor2_0:
         encoder_hidden_states: torch.FloatTensor = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        ip_hidden_states: Optional[List[torch.Tensor]] = None,
+        ip_adapter_masks: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-        context_input_ndim = encoder_hidden_states.ndim
-        if context_input_ndim == 4:
-            batch_size, channel, height, width = encoder_hidden_states.shape
-            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size = encoder_hidden_states.shape[0]
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

         # `sample` projections.
-
+        hidden_states_query_proj = attn.to_q(hidden_states)
         key = attn.to_k(hidden_states)
         value = attn.to_v(hidden_states)

         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads

-
+        hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

         if attn.norm_q is not None:
-
+            hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
         if attn.norm_k is not None:
             key = attn.norm_k(key)

-        #
-
-
-
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        if encoder_hidden_states is not None:
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

-
-
-
-
-
-
-
-
-
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)

-
-
-
-
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

-
-
-
-
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

         if image_rotary_emb is not None:
-
-
-
-
-            query, key = apply_rope(query, key, image_rotary_emb)
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)

         hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

-        encoder_hidden_states
-        hidden_states
-
-
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )

-
-
-
-
-
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

-
-
-
-
+            # IP-adapter
+            ip_query = hidden_states_query_proj
+            ip_attn_output = None
+            # for ip-adapter
+            # TODO: support for multiple adapters
+            for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
+                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
+            ):
+                ip_key = to_k_ip(current_ip_hidden_states)
+                ip_value = to_v_ip(current_ip_hidden_states)
+
+                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                # TODO: add support for attn.scale when we move to Torch 2.1
+                ip_attn_output = F.scaled_dot_product_attention(
+                    ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                )
+                ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+                ip_attn_output = scale * ip_attn_output
+                ip_attn_output = ip_attn_output.to(ip_query.dtype)

-
+            return hidden_states, encoder_hidden_states, ip_attn_output
+        else:
+            return hidden_states


 class CogVideoXAttnProcessor2_0:
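For orientation only (not part of the diff): the joint Flux processors added above are normally swapped onto a Flux transformer through `set_attn_processor`. A minimal, hedged sketch follows; the model id and dtype are illustrative assumptions, not values taken from this diff.

    # Illustrative sketch, assuming a locally available FLUX checkpoint.
    import torch
    from diffusers import FluxTransformer2DModel
    from diffusers.models.attention_processor import FluxAttnProcessor2_0

    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
    )
    # Route every attention layer through the SDPA-based joint processor introduced above.
    transformer.set_attn_processor(FluxAttnProcessor2_0())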
@@ -2260,7 +3188,217 @@ class AttnProcessorNPU:
                 inner_precise=0,
             )[0]
         else:
-            # TODO: add support for attn.scale when we move to Torch 2.1
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class AttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class XLAFlashAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
+    """
+
+    def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+        if is_torch_xla_version("<", "2.3"):
+            raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
+        if is_spmd() and is_torch_xla_version("<", "2.4"):
+            raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
+        self.partition_spec = partition_spec
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        if all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
+            if attention_mask is not None:
+                attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
+                # Convert mask to float and replace 0s with -inf and 1s with 0
+                attention_mask = (
+                    attention_mask.float()
+                    .masked_fill(attention_mask == 0, float("-inf"))
+                    .masked_fill(attention_mask == 1, float(0.0))
+                )
+
+                # Apply attention mask to key
+                key = key + attention_mask
+            query /= math.sqrt(query.shape[3])
+            partition_spec = self.partition_spec if is_spmd() else None
+            hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
+        else:
+            logger.warning(
+                "Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096."
+            )
             hidden_states = F.scaled_dot_product_attention(
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
             )
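A note on the XLAFlashAttnProcessor2_0 added above: it only takes the pallas flash-attention path when the query/key/value sequence lengths are at least 4096 and otherwise falls back to `F.scaled_dot_product_attention`. A hedged sketch of installing it manually, assuming torch_xla >= 2.3 and an already-loaded diffusers model named `unet`:

    # Minimal sketch; `unet` is assumed to be a loaded diffusers model exposing set_attn_processor.
    from diffusers.models.attention_processor import XLAFlashAttnProcessor2_0

    unet.set_attn_processor(XLAFlashAttnProcessor2_0(partition_spec=None))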
@@ -2284,9 +3422,9 @@ class AttnProcessorNPU:
         return hidden_states


-class
+class MochiVaeAttnProcessor2_0:
     r"""
-
+    Attention processor used in Mochi VAE.
     """

     def __init__(self):
@@ -2299,23 +3437,9 @@ class AttnProcessor2_0:
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        temb: Optional[torch.Tensor] = None,
-        *args,
-        **kwargs,
     ) -> torch.Tensor:
-        if len(args) > 0 or kwargs.get("scale", None) is not None:
-            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
-            deprecate("scale", "1.0.0", deprecation_message)
-
         residual = hidden_states
-
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        is_single_frame = hidden_states.shape[1] == 1

         batch_size, sequence_length, _ = (
             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
@@ -2327,15 +3451,24 @@ class AttnProcessor2_0:
             # (batch, heads, source_length, target_length)
             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

-        if
-            hidden_states = attn.
+        if is_single_frame:
+            hidden_states = attn.to_v(hidden_states)
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+            if attn.residual_connection:
+                hidden_states = hidden_states + residual
+
+            hidden_states = hidden_states / attn.rescale_output_factor
+            return hidden_states

         query = attn.to_q(hidden_states)

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
@@ -2344,7 +3477,6 @@ class AttnProcessor2_0:
         head_dim = inner_dim // attn.heads

         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

@@ -2356,7 +3488,7 @@ class AttnProcessor2_0:
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=attn.is_causal
         )

         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@@ -2367,9 +3499,6 @@ class AttnProcessor2_0:
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
         if attn.residual_connection:
             hidden_states = hidden_states + residual

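The small hunks above turn the copied processor into MochiVaeAttnProcessor2_0: the single-frame fast path is added and the SDPA call now honours `attn.is_causal` instead of hard-coding `is_causal=False`. A hedged sketch of what that enables; the sizes are illustrative and the exact `Attention` keyword set is assumed from the surrounding code rather than stated in this diff:

    # Sketch only: a causal self-attention block whose flag flows through to SDPA via `attn.is_causal`.
    from diffusers.models.attention_processor import Attention, MochiVaeAttnProcessor2_0

    attn = Attention(
        query_dim=768,          # illustrative width
        heads=8,
        dim_head=96,
        is_causal=True,         # read back by the processor at call time
        processor=MochiVaeAttnProcessor2_0(),
    )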
@@ -3572,34 +4701,232 @@ class SpatialNorm(nn.Module):
     """
     Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.

-    Args:
-        f_channels (`int`):
-            The number of channels for input to group normalization layer, and output of the spatial norm layer.
-        zq_channels (`int`):
-            The number of channels for the quantized vector as described in the paper.
-    """
+    Args:
+        f_channels (`int`):
+            The number of channels for input to group normalization layer, and output of the spatial norm layer.
+        zq_channels (`int`):
+            The number of channels for the quantized vector as described in the paper.
+    """
+
+    def __init__(
+        self,
+        f_channels: int,
+        zq_channels: int,
+    ):
+        super().__init__()
+        self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
+        self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+        self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
+        f_size = f.shape[-2:]
+        zq = F.interpolate(zq, size=f_size, mode="nearest")
+        norm_f = self.norm_layer(f)
+        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
+        return new_f
+
+
+class IPAdapterAttnProcessor(nn.Module):
+    r"""
+    Attention processor for Multiple IP-Adapters.
+
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`):
+            The number of channels in the `encoder_hidden_states`.
+        num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
+            The context length of the image features.
+        scale (`float` or List[`float`], defaults to 1.0):
+            the weight scale of image prompt.
+    """
+
+    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+
+        if not isinstance(num_tokens, (tuple, list)):
+            num_tokens = [num_tokens]
+        self.num_tokens = num_tokens
+
+        if not isinstance(scale, list):
+            scale = [scale] * len(num_tokens)
+        if len(scale) != len(num_tokens):
+            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
+        self.scale = scale
+
+        self.to_k_ip = nn.ModuleList(
+            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+        )
+        self.to_v_ip = nn.ModuleList(
+            [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+        )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        scale: float = 1.0,
+        ip_adapter_masks: Optional[torch.Tensor] = None,
+    ):
+        residual = hidden_states
+
+        # separate ip_hidden_states from encoder_hidden_states
+        if encoder_hidden_states is not None:
+            if isinstance(encoder_hidden_states, tuple):
+                encoder_hidden_states, ip_hidden_states = encoder_hidden_states
+            else:
+                deprecation_message = (
+                    "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
+                    " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
+                )
+                deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
+                end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
+                encoder_hidden_states, ip_hidden_states = (
+                    encoder_hidden_states[:, :end_pos, :],
+                    [encoder_hidden_states[:, end_pos:, :]],
+                )
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if ip_adapter_masks is not None:
+            if not isinstance(ip_adapter_masks, List):
+                # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+                ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+            if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
+                raise ValueError(
+                    f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+                    f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+                    f"({len(ip_hidden_states)})"
+                )
+            else:
+                for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+                    if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+                        raise ValueError(
+                            "Each element of the ip_adapter_masks array should be a tensor with shape "
+                            "[1, num_images_for_ip_adapter, height, width]."
+                            " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+                        )
+                    if mask.shape[1] != ip_state.shape[1]:
+                        raise ValueError(
+                            f"Number of masks ({mask.shape[1]}) does not match "
+                            f"number of ip images ({ip_state.shape[1]}) at index {index}"
+                        )
+                    if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+                        raise ValueError(
+                            f"Number of masks ({mask.shape[1]}) does not match "
+                            f"number of scales ({len(scale)}) at index {index}"
+                        )
+        else:
+            ip_adapter_masks = [None] * len(self.scale)
+
+        # for ip-adapter
+        for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+        ):
+            skip = False
+            if isinstance(scale, list):
+                if all(s == 0 for s in scale):
+                    skip = True
+            elif scale == 0:
+                skip = True
+            if not skip:
+                if mask is not None:
+                    if not isinstance(scale, list):
+                        scale = [scale] * mask.shape[1]
+
+                    current_num_images = mask.shape[1]
+                    for i in range(current_num_images):
+                        ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+                        ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+                        ip_key = attn.head_to_batch_dim(ip_key)
+                        ip_value = attn.head_to_batch_dim(ip_value)
+
+                        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+                        _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+                        _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
+
+                        mask_downsample = IPAdapterMaskProcessor.downsample(
+                            mask[:, i, :, :],
+                            batch_size,
+                            _current_ip_hidden_states.shape[1],
+                            _current_ip_hidden_states.shape[2],
+                        )
+
+                        mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+
+                        hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+                else:
+                    ip_key = to_k_ip(current_ip_hidden_states)
+                    ip_value = to_v_ip(current_ip_hidden_states)
+
+                    ip_key = attn.head_to_batch_dim(ip_key)
+                    ip_value = attn.head_to_batch_dim(ip_value)
+
+                    ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+                    current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+                    current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
+
+                    hidden_states = hidden_states + scale * current_ip_hidden_states
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

-
-
-        f_channels: int,
-        zq_channels: int,
-    ):
-        super().__init__()
-        self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
-        self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
-        self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual

-
-
-
-        norm_f = self.norm_layer(f)
-        new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
-        return new_f
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states


-class
+class IPAdapterAttnProcessor2_0(torch.nn.Module):
     r"""
-    Attention processor for
+    Attention processor for IP-Adapter for PyTorch 2.0.

     Args:
         hidden_size (`int`):
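For orientation (not part of the diff): the IPAdapterAttnProcessor reintroduced above is built per attention layer, with one num_tokens/scale entry per attached IP-Adapter. A hedged sketch with illustrative sizes only:

    # Sketch only: two IP-Adapters on one attention layer; the widths are made-up examples.
    from diffusers.models.attention_processor import IPAdapterAttnProcessor

    processor = IPAdapterAttnProcessor(
        hidden_size=1280,
        cross_attention_dim=2048,
        num_tokens=(4, 16),   # context length of each adapter's image features
        scale=[0.6, 0.4],     # one weight per adapter
    )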
@@ -3608,13 +4935,18 @@ class IPAdapterAttnProcessor(nn.Module):
             The number of channels in the `encoder_hidden_states`.
         num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
             The context length of the image features.
-        scale (`float` or List[
+        scale (`float` or `List[float]`, defaults to 1.0):
             the weight scale of image prompt.
     """

     def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
         super().__init__()

+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
         self.hidden_size = hidden_size
         self.cross_attention_dim = cross_attention_dim

@@ -3675,7 +5007,12 @@ class IPAdapterAttnProcessor(nn.Module):
         batch_size, sequence_length, _ = (
             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         )
-
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@@ -3690,13 +5027,22 @@ class IPAdapterAttnProcessor(nn.Module):
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)

-
-
-        value = attn.head_to_batch_dim(value)
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads

-
-
-
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)

         if ip_adapter_masks is not None:
             if not isinstance(ip_adapter_masks, List):
@@ -3749,12 +5095,19 @@ class IPAdapterAttnProcessor(nn.Module):
                         ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
                         ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])

-                        ip_key = attn.
-                        ip_value = attn.
+                        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

-
-
-                        _current_ip_hidden_states =
+                        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                        # TODO: add support for attn.scale when we move to Torch 2.1
+                        _current_ip_hidden_states = F.scaled_dot_product_attention(
+                            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                        )
+
+                        _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
+                            batch_size, -1, attn.heads * head_dim
+                        )
+                        _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)

                         mask_downsample = IPAdapterMaskProcessor.downsample(
                             mask[:, i, :, :],
@@ -3764,18 +5117,24 @@ class IPAdapterAttnProcessor(nn.Module):
                         )

                         mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
-
                         hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
                 else:
                     ip_key = to_k_ip(current_ip_hidden_states)
                     ip_value = to_v_ip(current_ip_hidden_states)

-                    ip_key = attn.
-                    ip_value = attn.
+                    ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                    ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

-
-
-                    current_ip_hidden_states =
+                    # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                    # TODO: add support for attn.scale when we move to Torch 2.1
+                    current_ip_hidden_states = F.scaled_dot_product_attention(
+                        query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                    )
+
+                    current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+                        batch_size, -1, attn.heads * head_dim
+                    )
+                    current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)

                     hidden_states = hidden_states + scale * current_ip_hidden_states

@@ -3795,9 +5154,9 @@ class IPAdapterAttnProcessor(nn.Module):
         return hidden_states


-class
+class IPAdapterXFormersAttnProcessor(torch.nn.Module):
     r"""
-    Attention processor for IP-Adapter
+    Attention processor for IP-Adapter using xFormers.

     Args:
         hidden_size (`int`):
@@ -3808,18 +5167,26 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
  The context length of the image features.
  scale (`float` or `List[float]`, defaults to 1.0):
  the weight scale of image prompt.
+ attention_op (`Callable`, *optional*, defaults to `None`):
+ The base
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+ operator.
  """

- def __init__(
+ def __init__(
+ self,
+ hidden_size,
+ cross_attention_dim=None,
+ num_tokens=(4,),
+ scale=1.0,
+ attention_op: Optional[Callable] = None,
+ ):
  super().__init__()

- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
- )
-
  self.hidden_size = hidden_size
  self.cross_attention_dim = cross_attention_dim
+ self.attention_op = attention_op

  if not isinstance(num_tokens, (tuple, list)):
  num_tokens = [num_tokens]
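The new constructor mirrors `IPAdapterAttnProcessor2_0` but adds an xFormers `attention_op`. A hedged instantiation sketch (dimensions are hypothetical; assumes xformers is installed so the processor can run):

from diffusers.models.attention_processor import IPAdapterXFormersAttnProcessor

processor = IPAdapterXFormersAttnProcessor(
    hidden_size=1280,            # hypothetical attention width
    cross_attention_dim=2048,    # hypothetical image-feature dim
    num_tokens=(4,),
    scale=1.0,
    attention_op=None,           # let xFormers pick the best kernel
)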
@@ -3832,21 +5199,21 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
  self.scale = scale

  self.to_k_ip = nn.ModuleList(
- [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+ [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
  )
  self.to_v_ip = nn.ModuleList(
- [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+ [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
  )

  def __call__(
  self,
  attn: Attention,
- hidden_states: torch.
- encoder_hidden_states: Optional[torch.
- attention_mask: Optional[torch.
- temb: Optional[torch.
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ temb: Optional[torch.FloatTensor] = None,
  scale: float = 1.0,
- ip_adapter_masks: Optional[torch.
+ ip_adapter_masks: Optional[torch.FloatTensor] = None,
  ):
  residual = hidden_states

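The `cross_attention_dim or hidden_size` fallback keeps the IP projections valid when the processor is attached to a self-attention layer, where `cross_attention_dim` is `None`. A one-line sketch with hypothetical sizes:

import torch.nn as nn

hidden_size, cross_attention_dim = 1280, None
to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)  # in_features falls back to 1280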
@@ -3881,9 +5248,14 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):

  if attention_mask is not None:
  attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- #
- #
-
+ # expand our mask's singleton query_tokens dimension:
+ # [batch*heads, 1, key_tokens] ->
+ # [batch*heads, query_tokens, key_tokens]
+ # so that it can be added as a bias onto the attention scores that xformers computes:
+ # [batch*heads, query_tokens, key_tokens]
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
+ _, query_tokens, _ = hidden_states.shape
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)

  if attn.group_norm is not None:
  hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
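xFormers does not broadcast the singleton query dimension of an attention bias, hence the explicit `expand` above. A small sketch with hypothetical sizes:

import torch

# hypothetical sizes: batch*heads = 16 rows, 77 key tokens, 4096 query tokens
attention_mask = torch.zeros(16, 1, 77)               # singleton query_tokens dim
attention_mask = attention_mask.expand(-1, 4096, -1)  # -> [16, 4096, 77]; a broadcast view, no copy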
@@ -3898,131 +5270,291 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
  key = attn.to_k(encoder_hidden_states)
  value = attn.to_v(encoder_hidden_states)

+ query = attn.head_to_batch_dim(query).contiguous()
+ key = attn.head_to_batch_dim(key).contiguous()
+ value = attn.head_to_batch_dim(value).contiguous()
+
+ hidden_states = xformers.ops.memory_efficient_attention(
+ query, key, value, attn_bias=attention_mask, op=self.attention_op
+ )
+ hidden_states = hidden_states.to(query.dtype)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ if ip_hidden_states:
+ if ip_adapter_masks is not None:
+ if not isinstance(ip_adapter_masks, List):
+ # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+ ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+ if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
+ raise ValueError(
+ f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+ f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+ f"({len(ip_hidden_states)})"
+ )
+ else:
+ for index, (mask, scale, ip_state) in enumerate(
+ zip(ip_adapter_masks, self.scale, ip_hidden_states)
+ ):
+ if mask is None:
+ continue
+ if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+ raise ValueError(
+ "Each element of the ip_adapter_masks array should be a tensor with shape "
+ "[1, num_images_for_ip_adapter, height, width]."
+ " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+ )
+ if mask.shape[1] != ip_state.shape[1]:
+ raise ValueError(
+ f"Number of masks ({mask.shape[1]}) does not match "
+ f"number of ip images ({ip_state.shape[1]}) at index {index}"
+ )
+ if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+ raise ValueError(
+ f"Number of masks ({mask.shape[1]}) does not match "
+ f"number of scales ({len(scale)}) at index {index}"
+ )
+ else:
+ ip_adapter_masks = [None] * len(self.scale)
+
+ # for ip-adapter
+ for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+ ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+ ):
+ skip = False
+ if isinstance(scale, list):
+ if all(s == 0 for s in scale):
+ skip = True
+ elif scale == 0:
+ skip = True
+ if not skip:
+ if mask is not None:
+ mask = mask.to(torch.float16)
+ if not isinstance(scale, list):
+ scale = [scale] * mask.shape[1]
+
+ current_num_images = mask.shape[1]
+ for i in range(current_num_images):
+ ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+ ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+ ip_key = attn.head_to_batch_dim(ip_key).contiguous()
+ ip_value = attn.head_to_batch_dim(ip_value).contiguous()
+
+ _current_ip_hidden_states = xformers.ops.memory_efficient_attention(
+ query, ip_key, ip_value, op=self.attention_op
+ )
+ _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
+ _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
+
+ mask_downsample = IPAdapterMaskProcessor.downsample(
+ mask[:, i, :, :],
+ batch_size,
+ _current_ip_hidden_states.shape[1],
+ _current_ip_hidden_states.shape[2],
+ )
+
+ mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+ hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+ else:
+ ip_key = to_k_ip(current_ip_hidden_states)
+ ip_value = to_v_ip(current_ip_hidden_states)
+
+ ip_key = attn.head_to_batch_dim(ip_key).contiguous()
+ ip_value = attn.head_to_batch_dim(ip_value).contiguous()
+
+ current_ip_hidden_states = xformers.ops.memory_efficient_attention(
+ query, ip_key, ip_value, op=self.attention_op
+ )
+ current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+ current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
+
+ hidden_states = hidden_states + scale * current_ip_hidden_states
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+ class SD3IPAdapterJointAttnProcessor2_0(torch.nn.Module):
+ """
+ Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections, with
+ additional image-based information and timestep embeddings.
+
+ Args:
+ hidden_size (`int`):
+ The number of hidden channels.
+ ip_hidden_states_dim (`int`):
+ The image feature dimension.
+ head_dim (`int`):
+ The number of head channels.
+ timesteps_emb_dim (`int`, defaults to 1280):
+ The number of input channels for timestep embedding.
+ scale (`float`, defaults to 0.5):
+ IP-Adapter scale.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ ip_hidden_states_dim: int,
+ head_dim: int,
+ timesteps_emb_dim: int = 1280,
+ scale: float = 0.5,
+ ):
+ super().__init__()
+
+ # To prevent circular import
+ from .normalization import AdaLayerNorm, RMSNorm
+
+ self.norm_ip = AdaLayerNorm(timesteps_emb_dim, output_dim=ip_hidden_states_dim * 2, norm_eps=1e-6, chunk_dim=1)
+ self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
+ self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
+ self.norm_q = RMSNorm(head_dim, 1e-6)
+ self.norm_k = RMSNorm(head_dim, 1e-6)
+ self.norm_ip_k = RMSNorm(head_dim, 1e-6)
+ self.scale = scale
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ ip_hidden_states: torch.FloatTensor = None,
+ temb: torch.FloatTensor = None,
+ ) -> torch.FloatTensor:
+ """
+ Perform the attention computation, integrating image features (if provided) and timestep embeddings.
+
+ If `ip_hidden_states` is `None`, this is equivalent to using JointAttnProcessor2_0.
+
+ Args:
+ attn (`Attention`):
+ Attention instance.
+ hidden_states (`torch.FloatTensor`):
+ Input `hidden_states`.
+ encoder_hidden_states (`torch.FloatTensor`, *optional*):
+ The encoder hidden states.
+ attention_mask (`torch.FloatTensor`, *optional*):
+ Attention mask.
+ ip_hidden_states (`torch.FloatTensor`, *optional*):
+ Image embeddings.
+ temb (`torch.FloatTensor`, *optional*):
+ Timestep embeddings.
+
+ Returns:
+ `torch.FloatTensor`: Output hidden states.
+ """
+ residual = hidden_states
+
+ batch_size = hidden_states.shape[0]
+
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
  inner_dim = key.shape[-1]
  head_dim = inner_dim // attn.heads

  query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
  key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
  value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ img_query = query
+ img_key = key
+ img_value = value

-
-
-
-
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)

-
-
-
-
-
- raise ValueError(
- f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
- f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
- f"({len(ip_hidden_states)})"
- )
- else:
- for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
- if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
- raise ValueError(
- "Each element of the ip_adapter_masks array should be a tensor with shape "
- "[1, num_images_for_ip_adapter, height, width]."
- " Please use `IPAdapterMaskProcessor` to preprocess your mask"
- )
- if mask.shape[1] != ip_state.shape[1]:
- raise ValueError(
- f"Number of masks ({mask.shape[1]}) does not match "
- f"number of ip images ({ip_state.shape[1]}) at index {index}"
- )
- if isinstance(scale, list) and not len(scale) == mask.shape[1]:
- raise ValueError(
- f"Number of masks ({mask.shape[1]}) does not match "
- f"number of scales ({len(scale)}) at index {index}"
- )
- else:
- ip_adapter_masks = [None] * len(self.scale)
+ # `context` projections.
+ if encoder_hidden_states is not None:
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

-
-
-
-
-
-
-
-
-
- skip = True
- if not skip:
- if mask is not None:
- if not isinstance(scale, list):
- scale = [scale] * mask.shape[1]
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)

-
-
-
-
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

-
-
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)

-
-
-
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)

-
-
-
-
+ if encoder_hidden_states is not None:
+ # Split the attention outputs.
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, : residual.shape[1]],
+ hidden_states[:, residual.shape[1] :],
+ )
+ if not attn.context_pre_only:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

-
-
-
-
- _current_ip_hidden_states.shape[2],
- )
+ # IP Adapter
+ if self.scale != 0 and ip_hidden_states is not None:
+ # Norm image features
+ norm_ip_hidden_states = self.norm_ip(ip_hidden_states, temb=temb)

-
-
-
- ip_key = to_k_ip(current_ip_hidden_states)
- ip_value = to_v_ip(current_ip_hidden_states)
+ # To k and v
+ ip_key = self.to_k_ip(norm_ip_hidden_states)
+ ip_value = self.to_v_ip(norm_ip_hidden_states)

-
-
+ # Reshape
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

-
-
-
-
- )
+ # Norm
+ query = self.norm_q(img_query)
+ img_key = self.norm_k(img_key)
+ ip_key = self.norm_ip_k(ip_key)

-
-
-
- current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+ # cat img
+ key = torch.cat([img_key, ip_key], dim=2)
+ value = torch.cat([img_value, ip_value], dim=2)

-
+ ip_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).view(batch_size, -1, attn.heads * head_dim)
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
+
+ hidden_states = hidden_states + ip_hidden_states * self.scale

  # linear proj
  hidden_states = attn.to_out[0](hidden_states)
  # dropout
  hidden_states = attn.to_out[1](hidden_states)

- if
-
-
-
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
+ if encoder_hidden_states is not None:
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states


  class PAGIdentitySelfAttnProcessor2_0:
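The core of the new `SD3IPAdapterJointAttnProcessor2_0` is a second attention pass in which the image-token queries attend to the concatenation of image and IP key/value tokens, and the result is added back with `self.scale`. A standalone sketch with hypothetical sizes (the RMS norms and linear projections are omitted; nothing here is taken verbatim from the library):

import torch
import torch.nn.functional as F

# hypothetical sizes
batch, heads, head_dim, img_len, ip_len = 2, 24, 64, 4096, 64
scale = 0.5

img_query = torch.randn(batch, heads, img_len, head_dim)
img_key = torch.randn(batch, heads, img_len, head_dim)
img_value = torch.randn(batch, heads, img_len, head_dim)
ip_key = torch.randn(batch, heads, ip_len, head_dim)    # stand-in for projected, normed image embeddings
ip_value = torch.randn(batch, heads, ip_len, head_dim)

# image tokens attend jointly to image + IP tokens in a single SDPA call
key = torch.cat([img_key, ip_key], dim=2)
value = torch.cat([img_value, ip_value], dim=2)
ip_out = F.scaled_dot_product_attention(img_query, key, value, dropout_p=0.0, is_causal=False)
ip_out = ip_out.transpose(1, 2).reshape(batch, -1, heads * head_dim)

hidden_states = torch.randn(batch, img_len, heads * head_dim)
hidden_states = hidden_states + ip_out * scale          # scaled residual add of the IP branch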
@@ -4227,26 +5759,272 @@ class PAGCFGIdentitySelfAttnProcessor2_0:
  return hidden_states


+ class SanaMultiscaleAttnProcessor2_0:
+ r"""
+ Processor for implementing multiscale quadratic attention.
+ """
+
+ def __call__(self, attn: SanaMultiscaleLinearAttention, hidden_states: torch.Tensor) -> torch.Tensor:
+ height, width = hidden_states.shape[-2:]
+ if height * width > attn.attention_head_dim:
+ use_linear_attention = True
+ else:
+ use_linear_attention = False
+
+ residual = hidden_states
+
+ batch_size, _, height, width = list(hidden_states.size())
+ original_dtype = hidden_states.dtype
+
+ hidden_states = hidden_states.movedim(1, -1)
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+ hidden_states = torch.cat([query, key, value], dim=3)
+ hidden_states = hidden_states.movedim(-1, 1)
+
+ multi_scale_qkv = [hidden_states]
+ for block in attn.to_qkv_multiscale:
+ multi_scale_qkv.append(block(hidden_states))
+
+ hidden_states = torch.cat(multi_scale_qkv, dim=1)
+
+ if use_linear_attention:
+ # for linear attention upcast hidden_states to float32
+ hidden_states = hidden_states.to(dtype=torch.float32)
+
+ hidden_states = hidden_states.reshape(batch_size, -1, 3 * attn.attention_head_dim, height * width)
+
+ query, key, value = hidden_states.chunk(3, dim=2)
+ query = attn.nonlinearity(query)
+ key = attn.nonlinearity(key)
+
+ if use_linear_attention:
+ hidden_states = attn.apply_linear_attention(query, key, value)
+ hidden_states = hidden_states.to(dtype=original_dtype)
+ else:
+ hidden_states = attn.apply_quadratic_attention(query, key, value)
+
+ hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
+ hidden_states = attn.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+
+ if attn.norm_type == "rms_norm":
+ hidden_states = attn.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+ else:
+ hidden_states = attn.norm_out(hidden_states)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
  class LoRAAttnProcessor:
+ r"""
+ Processor for implementing attention with LoRA.
+ """
+
  def __init__(self):
  pass


  class LoRAAttnProcessor2_0:
+ r"""
+ Processor for implementing attention with LoRA (enabled by default if you're using PyTorch 2.0).
+ """
+
  def __init__(self):
  pass


  class LoRAXFormersAttnProcessor:
+ r"""
+ Processor for implementing attention with LoRA using xFormers.
+ """
+
  def __init__(self):
  pass


  class LoRAAttnAddedKVProcessor:
+ r"""
+ Processor for implementing attention with LoRA with extra learnable key and value matrices for the text encoder.
+ """
+
  def __init__(self):
  pass


+ class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(self):
+ deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
+ deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
+ super().__init__()
+
+
+ class SanaLinearAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product linear attention.
+ """
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ original_dtype = hidden_states.dtype
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
+ key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
+ value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
+
+ query = F.relu(query)
+ key = F.relu(key)
+
+ query, key, value = query.float(), key.float(), value.float()
+
+ value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
+ scores = torch.matmul(value, key)
+ hidden_states = torch.matmul(scores, query)
+
+ hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + 1e-15)
+ hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
+ hidden_states = hidden_states.to(original_dtype)
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if original_dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ return hidden_states
+
+
+ class PAGCFGSanaLinearAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product linear attention.
+ """
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ original_dtype = hidden_states.dtype
+
+ hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
+ hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
+
+ query = attn.to_q(hidden_states_org)
+ key = attn.to_k(hidden_states_org)
+ value = attn.to_v(hidden_states_org)
+
+ query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
+ key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
+ value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
+
+ query = F.relu(query)
+ key = F.relu(key)
+
+ query, key, value = query.float(), key.float(), value.float()
+
+ value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
+ scores = torch.matmul(value, key)
+ hidden_states_org = torch.matmul(scores, query)
+
+ hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + 1e-15)
+ hidden_states_org = hidden_states_org.flatten(1, 2).transpose(1, 2)
+ hidden_states_org = hidden_states_org.to(original_dtype)
+
+ hidden_states_org = attn.to_out[0](hidden_states_org)
+ hidden_states_org = attn.to_out[1](hidden_states_org)
+
+ # perturbed path (identity attention)
+ hidden_states_ptb = attn.to_v(hidden_states_ptb).to(original_dtype)
+
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+
+ if original_dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ return hidden_states
+
+
+ class PAGIdentitySanaLinearAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product linear attention.
+ """
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ original_dtype = hidden_states.dtype
+
+ hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
+
+ query = attn.to_q(hidden_states_org)
+ key = attn.to_k(hidden_states_org)
+ value = attn.to_v(hidden_states_org)
+
+ query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
+ key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
+ value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
+
+ query = F.relu(query)
+ key = F.relu(key)
+
+ query, key, value = query.float(), key.float(), value.float()
+
+ value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
+ scores = torch.matmul(value, key)
+ hidden_states_org = torch.matmul(scores, query)
+
+ if hidden_states_org.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states_org = hidden_states_org.float()
+
+ hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + 1e-15)
+ hidden_states_org = hidden_states_org.flatten(1, 2).transpose(1, 2)
+ hidden_states_org = hidden_states_org.to(original_dtype)
+
+ hidden_states_org = attn.to_out[0](hidden_states_org)
+ hidden_states_org = attn.to_out[1](hidden_states_org)
+
+ # perturbed path (identity attention)
+ hidden_states_ptb = attn.to_v(hidden_states_ptb).to(original_dtype)
+
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+
+ if original_dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ return hidden_states
+
+
  ADDED_KV_ATTENTION_PROCESSORS = (
  AttnAddedKVProcessor,
  SlicedAttnAddedKVProcessor,
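The three Sana processors added above share the same linear-attention trick: apply ReLU to queries and keys, pad the values with a row of ones so one pair of matmuls produces both the weighted sum and its normalizer, then divide. A numeric sketch with hypothetical sizes, following the (batch, heads, head_dim, seq) layout used in the processor:

import torch
import torch.nn.functional as F

# hypothetical sizes
batch, heads, head_dim, seq = 2, 16, 32, 1024
query = F.relu(torch.randn(batch, heads, head_dim, seq))
key = F.relu(torch.randn(batch, heads, head_dim, seq)).transpose(2, 3)   # (B, H, S, D)
value = torch.randn(batch, heads, head_dim, seq)

# append a row of ones so the same matmuls also accumulate the normalizer
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)           # (B, H, D+1, S)
scores = torch.matmul(value, key)                                        # (B, H, D+1, D)
out = torch.matmul(scores, query)                                        # (B, H, D+1, S)

# the last channel holds sum_j relu(k_j)·relu(q_i); divide by it to normalize
out = out[:, :, :-1] / (out[:, :, -1:] + 1e-15)                          # (B, H, D, S)
out = out.flatten(1, 2).transpose(1, 2)                                  # (B, S, H*D)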
@@ -4261,23 +6039,59 @@ CROSS_ATTENTION_PROCESSORS = (
  SlicedAttnProcessor,
  IPAdapterAttnProcessor,
  IPAdapterAttnProcessor2_0,
+ FluxIPAdapterJointAttnProcessor2_0,
  )

  AttentionProcessor = Union[
  AttnProcessor,
-
- FusedAttnProcessor2_0,
- XFormersAttnProcessor,
- SlicedAttnProcessor,
+ CustomDiffusionAttnProcessor,
  AttnAddedKVProcessor,
- SlicedAttnAddedKVProcessor,
  AttnAddedKVProcessor2_0,
+ JointAttnProcessor2_0,
+ PAGJointAttnProcessor2_0,
+ PAGCFGJointAttnProcessor2_0,
+ FusedJointAttnProcessor2_0,
+ AllegroAttnProcessor2_0,
+ AuraFlowAttnProcessor2_0,
+ FusedAuraFlowAttnProcessor2_0,
+ FluxAttnProcessor2_0,
+ FluxAttnProcessor2_0_NPU,
+ FusedFluxAttnProcessor2_0,
+ FusedFluxAttnProcessor2_0_NPU,
+ CogVideoXAttnProcessor2_0,
+ FusedCogVideoXAttnProcessor2_0,
  XFormersAttnAddedKVProcessor,
-
+ XFormersAttnProcessor,
+ XLAFlashAttnProcessor2_0,
+ AttnProcessorNPU,
+ AttnProcessor2_0,
+ MochiVaeAttnProcessor2_0,
+ MochiAttnProcessor2_0,
+ StableAudioAttnProcessor2_0,
+ HunyuanAttnProcessor2_0,
+ FusedHunyuanAttnProcessor2_0,
+ PAGHunyuanAttnProcessor2_0,
+ PAGCFGHunyuanAttnProcessor2_0,
+ LuminaAttnProcessor2_0,
+ FusedAttnProcessor2_0,
  CustomDiffusionXFormersAttnProcessor,
  CustomDiffusionAttnProcessor2_0,
-
+ SlicedAttnProcessor,
+ SlicedAttnAddedKVProcessor,
+ SanaLinearAttnProcessor2_0,
+ PAGCFGSanaLinearAttnProcessor2_0,
+ PAGIdentitySanaLinearAttnProcessor2_0,
+ SanaMultiscaleLinearAttention,
+ SanaMultiscaleAttnProcessor2_0,
+ SanaMultiscaleAttentionProjection,
+ IPAdapterAttnProcessor,
+ IPAdapterAttnProcessor2_0,
+ IPAdapterXFormersAttnProcessor,
+ SD3IPAdapterJointAttnProcessor2_0,
  PAGIdentitySelfAttnProcessor2_0,
-
-
+ PAGCFGIdentitySelfAttnProcessor2_0,
+ LoRAAttnProcessor,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnAddedKVProcessor,
  ]
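`AttentionProcessor` remains a `typing.Union`, so it only serves as an annotation helper; swapping or inspecting processors still goes through the concrete classes listed above. A hedged sketch, assuming a loaded diffusers model that exposes `attn_processors` and `set_attn_processor` (both are public on diffusers models):

from diffusers.models.attention_processor import AttnProcessor2_0

def reset_to_sdpa(model) -> None:
    # `set_attn_processor` accepts a single processor (applied to every attention layer)
    # or a dict keyed by the dotted names returned from `model.attn_processors`
    model.set_attn_processor(AttnProcessor2_0())

# inspection sketch:
# for name, proc in model.attn_processors.items():
#     print(name, type(proc).__name__)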