diffusers 0.29.2__py3-none-any.whl → 0.30.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +94 -3
- diffusers/commands/env.py +1 -5
- diffusers/configuration_utils.py +4 -9
- diffusers/dependency_versions_table.py +2 -2
- diffusers/image_processor.py +1 -2
- diffusers/loaders/__init__.py +17 -2
- diffusers/loaders/ip_adapter.py +10 -7
- diffusers/loaders/lora_base.py +752 -0
- diffusers/loaders/lora_pipeline.py +2252 -0
- diffusers/loaders/peft.py +213 -5
- diffusers/loaders/single_file.py +3 -14
- diffusers/loaders/single_file_model.py +31 -10
- diffusers/loaders/single_file_utils.py +293 -8
- diffusers/loaders/textual_inversion.py +1 -6
- diffusers/loaders/unet.py +23 -208
- diffusers/models/__init__.py +20 -0
- diffusers/models/activations.py +22 -0
- diffusers/models/attention.py +386 -7
- diffusers/models/attention_processor.py +1937 -629
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_kl.py +14 -3
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1271 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vq_model.py +4 -4
- diffusers/models/controlnet.py +2 -3
- diffusers/models/controlnet_hunyuan.py +401 -0
- diffusers/models/controlnet_sd3.py +11 -11
- diffusers/models/controlnet_sparsectrl.py +789 -0
- diffusers/models/controlnet_xs.py +40 -10
- diffusers/models/downsampling.py +68 -0
- diffusers/models/embeddings.py +403 -36
- diffusers/models/model_loading_utils.py +1 -3
- diffusers/models/modeling_flax_utils.py +1 -6
- diffusers/models/modeling_utils.py +4 -16
- diffusers/models/normalization.py +203 -12
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +543 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +485 -0
- diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
- diffusers/models/transformers/latte_transformer_3d.py +327 -0
- diffusers/models/transformers/lumina_nextdit2d.py +340 -0
- diffusers/models/transformers/pixart_transformer_2d.py +102 -1
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/stable_audio_transformer.py +458 -0
- diffusers/models/transformers/transformer_flux.py +455 -0
- diffusers/models/transformers/transformer_sd3.py +18 -4
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +8 -1
- diffusers/models/unets/unet_3d_blocks.py +51 -920
- diffusers/models/unets/unet_3d_condition.py +4 -1
- diffusers/models/unets/unet_i2vgen_xl.py +4 -1
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +1330 -84
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +1 -3
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +64 -0
- diffusers/models/vq_model.py +8 -4
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +100 -3
- diffusers/pipelines/animatediff/__init__.py +4 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
- diffusers/pipelines/aura_flow/__init__.py +48 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
- diffusers/pipelines/auto_pipeline.py +97 -19
- diffusers/pipelines/cogvideo/__init__.py +48 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +746 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +749 -0
- diffusers/pipelines/flux/pipeline_output.py +21 -0
- diffusers/pipelines/free_init_utils.py +2 -0
- diffusers/pipelines/free_noise_utils.py +236 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
- diffusers/pipelines/kolors/__init__.py +54 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
- diffusers/pipelines/kolors/pipeline_output.py +21 -0
- diffusers/pipelines/kolors/text_encoder.py +889 -0
- diffusers/pipelines/kolors/tokenizer.py +334 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
- diffusers/pipelines/latte/__init__.py +48 -0
- diffusers/pipelines/latte/pipeline_latte.py +881 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
- diffusers/pipelines/lumina/__init__.py +48 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
- diffusers/pipelines/pag/__init__.py +67 -0
- diffusers/pipelines/pag/pag_utils.py +237 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
- diffusers/pipelines/pia/pipeline_pia.py +30 -37
- diffusers/pipelines/pipeline_flax_utils.py +4 -9
- diffusers/pipelines/pipeline_loading_utils.py +0 -3
- diffusers/pipelines/pipeline_utils.py +2 -14
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
- diffusers/pipelines/stable_audio/__init__.py +50 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
- diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
- diffusers/schedulers/__init__.py +8 -0
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
- diffusers/schedulers/scheduling_ddim.py +1 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
- diffusers/schedulers/scheduling_ddpm.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +2 -2
- diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
- diffusers/schedulers/scheduling_ipndm.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
- diffusers/schedulers/scheduling_utils.py +1 -3
- diffusers/schedulers/scheduling_utils_flax.py +1 -3
- diffusers/training_utils.py +99 -14
- diffusers/utils/__init__.py +2 -2
- diffusers/utils/dummy_pt_objects.py +210 -0
- diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
- diffusers/utils/dynamic_modules_utils.py +1 -11
- diffusers/utils/export_utils.py +50 -6
- diffusers/utils/hub_utils.py +45 -42
- diffusers/utils/import_utils.py +37 -15
- diffusers/utils/loading_utils.py +80 -3
- diffusers/utils/testing_utils.py +11 -8
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/METADATA +73 -83
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/RECORD +217 -164
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/WHEEL +1 -1
- diffusers/loaders/autoencoder.py +0 -146
- diffusers/loaders/controlnet.py +0 -136
- diffusers/loaders/lora.py +0 -1728
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/LICENSE +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/top_level.txt +0 -0
@@ -13,8 +13,7 @@
 # limitations under the License.
 import inspect
 import math
-from importlib import import_module
-from typing import Callable, List, Optional, Union
+from typing import Callable, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -23,8 +22,7 @@ from torch import nn
 from ..image_processor import IPAdapterMaskProcessor
 from ..utils import deprecate, logging
 from ..utils.import_utils import is_torch_npu_available, is_xformers_available
-from ..utils.torch_utils import maybe_allow_in_graph
-from .lora import LoRALinearLayer
+from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -51,6 +49,10 @@ class Attention(nn.Module):
             The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
         heads (`int`, *optional*, defaults to 8):
             The number of heads to use for multi-head attention.
+        kv_heads (`int`, *optional*, defaults to `None`):
+            The number of key and value heads to use for multi-head attention. Defaults to `heads`. If
+            `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi
+            Query Attention (MQA) otherwise GQA is used.
         dim_head (`int`, *optional*, defaults to 64):
             The number of channels in each head.
         dropout (`float`, *optional*, defaults to 0.0):
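The kv_heads description above is the grouped-query knob: shrinking the key/value projection width is all that changes. A minimal sketch of the sizing rule it implies, in plain PyTorch rather than the diffusers Attention class itself (the helper name is illustrative):

from typing import Optional

import torch.nn as nn


def make_qkv_projections(query_dim: int, heads: int = 8, dim_head: int = 64, kv_heads: Optional[int] = None):
    # kv_heads == heads     -> multi-head attention (MHA)
    # kv_heads == 1         -> multi-query attention (MQA)
    # 1 < kv_heads < heads  -> grouped-query attention (GQA)
    inner_dim = dim_head * heads
    inner_kv_dim = inner_dim if kv_heads is None else dim_head * kv_heads
    to_q = nn.Linear(query_dim, inner_dim, bias=False)
    to_k = nn.Linear(query_dim, inner_kv_dim, bias=False)
    to_v = nn.Linear(query_dim, inner_kv_dim, bias=False)
    return to_q, to_k, to_v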
@@ -96,6 +98,7 @@ class Attention(nn.Module):
         query_dim: int,
         cross_attention_dim: Optional[int] = None,
         heads: int = 8,
+        kv_heads: Optional[int] = None,
         dim_head: int = 64,
         dropout: float = 0.0,
         bias: bool = False,
@@ -105,6 +108,7 @@ class Attention(nn.Module):
         cross_attention_norm_num_groups: int = 32,
         qk_norm: Optional[str] = None,
         added_kv_proj_dim: Optional[int] = None,
+        added_proj_bias: Optional[bool] = True,
         norm_num_groups: Optional[int] = None,
         spatial_norm_dim: Optional[int] = None,
         out_bias: bool = True,
@@ -117,9 +121,15 @@ class Attention(nn.Module):
         processor: Optional["AttnProcessor"] = None,
         out_dim: int = None,
         context_pre_only=None,
+        pre_only=False,
     ):
         super().__init__()
+
+        # To prevent circular import.
+        from .normalization import FP32LayerNorm, RMSNorm
+
         self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
         self.query_dim = query_dim
         self.use_bias = bias
         self.is_cross_attention = cross_attention_dim is not None
@@ -132,6 +142,7 @@ class Attention(nn.Module):
         self.fused_projections = False
         self.out_dim = out_dim if out_dim is not None else query_dim
         self.context_pre_only = context_pre_only
+        self.pre_only = pre_only
 
         # we make use of this private variable to know whether this class is loaded
         # with an deprecated state dict so that we can convert it on the fly
@@ -170,6 +181,16 @@ class Attention(nn.Module):
         elif qk_norm == "layer_norm":
             self.norm_q = nn.LayerNorm(dim_head, eps=eps)
             self.norm_k = nn.LayerNorm(dim_head, eps=eps)
+        elif qk_norm == "fp32_layer_norm":
+            self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+            self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+        elif qk_norm == "layer_norm_across_heads":
+            # Lumina applys qk norm across all heads
+            self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
+            self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
+        elif qk_norm == "rms_norm":
+            self.norm_q = RMSNorm(dim_head, eps=eps)
+            self.norm_k = RMSNorm(dim_head, eps=eps)
         else:
             raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
 
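The new qk_norm options normalize the query and key per attention head before the dot product. A minimal sketch of what the "rms_norm" choice does, assuming standard RMSNorm semantics and omitting the learnable scale that the diffusers RMSNorm class can carry:

import torch


def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize the last dimension (dim_head) by its root mean square.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


q = torch.randn(2, 8, 16, 64)  # (batch, heads, tokens, dim_head)
q = rms_norm(q)                # applied to query (and likewise to key) before attention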
@@ -200,25 +221,38 @@ class Attention(nn.Module):
 
         if not self.only_cross_attention:
             # only relevant for the `AddedKVProcessor` classes
-            self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
-            self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
+            self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+            self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
         else:
            self.to_k = None
            self.to_v = None
 
+        self.added_proj_bias = added_proj_bias
         if self.added_kv_proj_dim is not None:
-            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
-            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
+            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
             if self.context_pre_only is not None:
-                self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
+                self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
 
-        self.to_out = nn.ModuleList([])
-        self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
-        self.to_out.append(nn.Dropout(dropout))
+        if not self.pre_only:
+            self.to_out = nn.ModuleList([])
+            self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+            self.to_out.append(nn.Dropout(dropout))
 
         if self.context_pre_only is not None and not self.context_pre_only:
             self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
 
+        if qk_norm is not None and added_kv_proj_dim is not None:
+            if qk_norm == "fp32_layer_norm":
+                self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+                self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+            elif qk_norm == "rms_norm":
+                self.norm_added_q = RMSNorm(dim_head, eps=eps)
+                self.norm_added_k = RMSNorm(dim_head, eps=eps)
+        else:
+            self.norm_added_q = None
+            self.norm_added_k = None
+
         # set attention processor
         # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
@@ -259,10 +293,6 @@ class Attention(nn.Module):
                 The attention operation to use. Defaults to `None` which uses the default attention operation from
                 `xformers`.
         """
-        is_lora = hasattr(self, "processor") and isinstance(
-            self.processor,
-            LORA_ATTENTION_PROCESSORS,
-        )
         is_custom_diffusion = hasattr(self, "processor") and isinstance(
             self.processor,
             (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
@@ -274,14 +304,13 @@ class Attention(nn.Module):
                 AttnAddedKVProcessor2_0,
                 SlicedAttnAddedKVProcessor,
                 XFormersAttnAddedKVProcessor,
-                LoRAAttnAddedKVProcessor,
             ),
         )
 
         if use_memory_efficient_attention_xformers:
-            if is_added_kv_processor and (is_lora or is_custom_diffusion):
+            if is_added_kv_processor and is_custom_diffusion:
                 raise NotImplementedError(
-                    f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
+                    f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}"
                 )
             if not is_xformers_available():
                 raise ModuleNotFoundError(
@@ -307,18 +336,7 @@ class Attention(nn.Module):
             except Exception as e:
                 raise e
 
-            if is_lora:
-                # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
-                # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
-                processor = LoRAXFormersAttnProcessor(
-                    hidden_size=self.processor.hidden_size,
-                    cross_attention_dim=self.processor.cross_attention_dim,
-                    rank=self.processor.rank,
-                    attention_op=attention_op,
-                )
-                processor.load_state_dict(self.processor.state_dict())
-                processor.to(self.processor.to_q_lora.up.weight.device)
-            elif is_custom_diffusion:
+            if is_custom_diffusion:
                 processor = CustomDiffusionXFormersAttnProcessor(
                     train_kv=self.processor.train_kv,
                     train_q_out=self.processor.train_q_out,
@@ -341,18 +359,7 @@ class Attention(nn.Module):
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
-            if is_lora:
-                attn_processor_class = (
-                    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
-                )
-                processor = attn_processor_class(
-                    hidden_size=self.processor.hidden_size,
-                    cross_attention_dim=self.processor.cross_attention_dim,
-                    rank=self.processor.rank,
-                )
-                processor.load_state_dict(self.processor.state_dict())
-                processor.to(self.processor.to_q_lora.up.weight.device)
-            elif is_custom_diffusion:
+            if is_custom_diffusion:
                 attn_processor_class = (
                     CustomDiffusionAttnProcessor2_0
                     if hasattr(F, "scaled_dot_product_attention")
@@ -442,82 +449,6 @@ class Attention(nn.Module):
         if not return_deprecated_lora:
             return self.processor
 
-        # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
-        # serialization format for LoRA Attention Processors. It should be deleted once the integration
-        # with PEFT is completed.
-        is_lora_activated = {
-            name: module.lora_layer is not None
-            for name, module in self.named_modules()
-            if hasattr(module, "lora_layer")
-        }
-
-        # 1. if no layer has a LoRA activated we can return the processor as usual
-        if not any(is_lora_activated.values()):
-            return self.processor
-
-        # If doesn't apply LoRA do `add_k_proj` or `add_v_proj`
-        is_lora_activated.pop("add_k_proj", None)
-        is_lora_activated.pop("add_v_proj", None)
-        # 2. else it is not possible that only some layers have LoRA activated
-        if not all(is_lora_activated.values()):
-            raise ValueError(
-                f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
-            )
-
-        # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
-        non_lora_processor_cls_name = self.processor.__class__.__name__
-        lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
-
-        hidden_size = self.inner_dim
-
-        # now create a LoRA attention processor from the LoRA layers
-        if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
-            kwargs = {
-                "cross_attention_dim": self.cross_attention_dim,
-                "rank": self.to_q.lora_layer.rank,
-                "network_alpha": self.to_q.lora_layer.network_alpha,
-                "q_rank": self.to_q.lora_layer.rank,
-                "q_hidden_size": self.to_q.lora_layer.out_features,
-                "k_rank": self.to_k.lora_layer.rank,
-                "k_hidden_size": self.to_k.lora_layer.out_features,
-                "v_rank": self.to_v.lora_layer.rank,
-                "v_hidden_size": self.to_v.lora_layer.out_features,
-                "out_rank": self.to_out[0].lora_layer.rank,
-                "out_hidden_size": self.to_out[0].lora_layer.out_features,
-            }
-
-            if hasattr(self.processor, "attention_op"):
-                kwargs["attention_op"] = self.processor.attention_op
-
-            lora_processor = lora_processor_cls(hidden_size, **kwargs)
-            lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
-            lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
-            lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
-            lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
-        elif lora_processor_cls == LoRAAttnAddedKVProcessor:
-            lora_processor = lora_processor_cls(
-                hidden_size,
-                cross_attention_dim=self.add_k_proj.weight.shape[0],
-                rank=self.to_q.lora_layer.rank,
-                network_alpha=self.to_q.lora_layer.network_alpha,
-            )
-            lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
-            lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
-            lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
-            lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
-
-            # only save if used
-            if self.add_k_proj.lora_layer is not None:
-                lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
-                lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
-            else:
-                lora_processor.add_k_proj_lora = None
-                lora_processor.add_v_proj_lora = None
-        else:
-            raise ValueError(f"{lora_processor_cls} does not exist.")
-
-        return lora_processor
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -609,7 +540,7 @@ class Attention(nn.Module):
         return tensor
 
     def get_attention_scores(
-        self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
+        self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         r"""
         Compute the attention scores.
@@ -760,6 +691,24 @@ class Attention(nn.Module):
                 concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
                 self.to_kv.bias.copy_(concatenated_bias)
 
+        # handle added projections for SD3 and others.
+        if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
+            concatenated_weights = torch.cat(
+                [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
+            )
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_added_qkv = nn.Linear(
+                in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
+            )
+            self.to_added_qkv.weight.copy_(concatenated_weights)
+            if self.added_proj_bias:
+                concatenated_bias = torch.cat(
+                    [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
+                )
+                self.to_added_qkv.bias.copy_(concatenated_bias)
+
         self.fused_projections = fuse
 
 
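The block added above extends projection fusing to the added (context) projections used by SD3-style joint attention: the three weight matrices are concatenated so a single matmul yields Q, K and V. A minimal sketch of that fusion on plain nn.Linear modules, assuming the three projections share input and output widths and bias setting (illustrative, not the diffusers method itself):

import torch
import torch.nn as nn


@torch.no_grad()
def fuse_three_linears(q: nn.Linear, k: nn.Linear, v: nn.Linear) -> nn.Linear:
    # Stack the output dimensions so one matmul produces [Q | K | V].
    weight = torch.cat([q.weight.data, k.weight.data, v.weight.data])
    fused = nn.Linear(weight.shape[1], weight.shape[0], bias=q.bias is not None)
    fused.weight.copy_(weight)
    if q.bias is not None:
        fused.bias.copy_(torch.cat([q.bias.data, k.bias.data, v.bias.data]))
    return fused

# usage (equal widths): qkv = fused(x); query, key, value = qkv.chunk(3, dim=-1)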
@@ -1132,9 +1081,7 @@ class JointAttnProcessor2_0:
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
-        hidden_states = hidden_states = F.scaled_dot_product_attention(
-            query, key, value, dropout_p=0.0, is_causal=False
-        )
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
@@ -1159,21 +1106,20 @@
         return hidden_states, encoder_hidden_states
 
 
-class FusedJointAttnProcessor2_0:
+class PAGJointAttnProcessor2_0:
     """Attention processor used typically in processing the SD3-like self-attention projections."""
 
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+            raise ImportError(
+                "PAGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
 
     def __call__(
         self,
         attn: Attention,
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        *args,
-        **kwargs,
     ) -> torch.FloatTensor:
         residual = hidden_states
 
@@ -1186,257 +1132,1409 @@ class FusedJointAttnProcessor2_0:
|
|
1186
1132
|
batch_size, channel, height, width = encoder_hidden_states.shape
|
1187
1133
|
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1188
1134
|
|
1189
|
-
|
1135
|
+
# store the length of image patch sequences to create a mask that prevents interaction between patches
|
1136
|
+
# similar to making the self-attention map an identity matrix
|
1137
|
+
identity_block_size = hidden_states.shape[1]
|
1138
|
+
|
1139
|
+
# chunk
|
1140
|
+
hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
|
1141
|
+
encoder_hidden_states_org, encoder_hidden_states_ptb = encoder_hidden_states.chunk(2)
|
1142
|
+
|
1143
|
+
################## original path ##################
|
1144
|
+
batch_size = encoder_hidden_states_org.shape[0]
|
1190
1145
|
|
1191
1146
|
# `sample` projections.
|
1192
|
-
|
1193
|
-
|
1194
|
-
|
1147
|
+
query_org = attn.to_q(hidden_states_org)
|
1148
|
+
key_org = attn.to_k(hidden_states_org)
|
1149
|
+
value_org = attn.to_v(hidden_states_org)
|
1195
1150
|
|
1196
1151
|
# `context` projections.
|
1197
|
-
|
1198
|
-
|
1199
|
-
(
|
1200
|
-
encoder_hidden_states_query_proj,
|
1201
|
-
encoder_hidden_states_key_proj,
|
1202
|
-
encoder_hidden_states_value_proj,
|
1203
|
-
) = torch.split(encoder_qkv, split_size, dim=-1)
|
1152
|
+
encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
|
1153
|
+
encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
|
1154
|
+
encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
|
1204
1155
|
|
1205
1156
|
# attention
|
1206
|
-
|
1207
|
-
|
1208
|
-
|
1157
|
+
query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
|
1158
|
+
key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
|
1159
|
+
value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
|
1209
1160
|
|
1210
|
-
inner_dim =
|
1161
|
+
inner_dim = key_org.shape[-1]
|
1211
1162
|
head_dim = inner_dim // attn.heads
|
1212
|
-
|
1213
|
-
|
1214
|
-
|
1163
|
+
query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1164
|
+
key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1165
|
+
value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1215
1166
|
|
1216
|
-
|
1217
|
-
|
1167
|
+
hidden_states_org = F.scaled_dot_product_attention(
|
1168
|
+
query_org, key_org, value_org, dropout_p=0.0, is_causal=False
|
1218
1169
|
)
|
1219
|
-
|
1220
|
-
|
1170
|
+
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1171
|
+
hidden_states_org = hidden_states_org.to(query_org.dtype)
|
1221
1172
|
|
1222
1173
|
# Split the attention outputs.
|
1223
|
-
|
1224
|
-
|
1225
|
-
|
1174
|
+
hidden_states_org, encoder_hidden_states_org = (
|
1175
|
+
hidden_states_org[:, : residual.shape[1]],
|
1176
|
+
hidden_states_org[:, residual.shape[1] :],
|
1226
1177
|
)
|
1227
1178
|
|
1228
1179
|
# linear proj
|
1229
|
-
|
1180
|
+
hidden_states_org = attn.to_out[0](hidden_states_org)
|
1230
1181
|
# dropout
|
1231
|
-
|
1182
|
+
hidden_states_org = attn.to_out[1](hidden_states_org)
|
1232
1183
|
if not attn.context_pre_only:
|
1233
|
-
|
1184
|
+
encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
|
1234
1185
|
|
1235
1186
|
if input_ndim == 4:
|
1236
|
-
|
1187
|
+
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1237
1188
|
if context_input_ndim == 4:
|
1238
|
-
|
1239
|
-
|
1240
|
-
|
1189
|
+
encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
|
1190
|
+
batch_size, channel, height, width
|
1191
|
+
)
|
1241
1192
|
|
1193
|
+
################## perturbed path ##################
|
1242
1194
|
|
1243
|
-
|
1244
|
-
r"""
|
1245
|
-
Processor for implementing memory efficient attention using xFormers.
|
1195
|
+
batch_size = encoder_hidden_states_ptb.shape[0]
|
1246
1196
|
|
1247
|
-
|
1248
|
-
|
1249
|
-
|
1250
|
-
|
1251
|
-
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
1252
|
-
operator.
|
1253
|
-
"""
|
1197
|
+
# `sample` projections.
|
1198
|
+
query_ptb = attn.to_q(hidden_states_ptb)
|
1199
|
+
key_ptb = attn.to_k(hidden_states_ptb)
|
1200
|
+
value_ptb = attn.to_v(hidden_states_ptb)
|
1254
1201
|
|
1255
|
-
|
1256
|
-
|
1202
|
+
# `context` projections.
|
1203
|
+
encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
|
1204
|
+
encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
|
1205
|
+
encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
|
1257
1206
|
|
1258
|
-
|
1259
|
-
|
1260
|
-
|
1261
|
-
|
1262
|
-
encoder_hidden_states: Optional[torch.Tensor] = None,
|
1263
|
-
attention_mask: Optional[torch.Tensor] = None,
|
1264
|
-
) -> torch.Tensor:
|
1265
|
-
residual = hidden_states
|
1266
|
-
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
1267
|
-
batch_size, sequence_length, _ = hidden_states.shape
|
1207
|
+
# attention
|
1208
|
+
query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
|
1209
|
+
key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
|
1210
|
+
value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
|
1268
1211
|
|
1269
|
-
|
1212
|
+
inner_dim = key_ptb.shape[-1]
|
1213
|
+
head_dim = inner_dim // attn.heads
|
1214
|
+
query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1215
|
+
key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1216
|
+
value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1270
1217
|
|
1271
|
-
|
1272
|
-
|
1273
|
-
|
1274
|
-
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1218
|
+
# create a full mask with all entries set to 0
|
1219
|
+
seq_len = query_ptb.size(2)
|
1220
|
+
full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
|
1275
1221
|
|
1276
|
-
|
1222
|
+
# set the attention value between image patches to -inf
|
1223
|
+
full_mask[:identity_block_size, :identity_block_size] = float("-inf")
|
1277
1224
|
|
1278
|
-
|
1279
|
-
|
1225
|
+
# set the diagonal of the attention value between image patches to 0
|
1226
|
+
full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
|
1280
1227
|
|
1281
|
-
|
1282
|
-
|
1283
|
-
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
1284
|
-
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
1228
|
+
# expand the mask to match the attention weights shape
|
1229
|
+
full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions
|
1285
1230
|
|
1286
|
-
|
1287
|
-
|
1288
|
-
|
1289
|
-
|
1290
|
-
|
1291
|
-
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
1292
|
-
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
1293
|
-
else:
|
1294
|
-
key = encoder_hidden_states_key_proj
|
1295
|
-
value = encoder_hidden_states_value_proj
|
1231
|
+
hidden_states_ptb = F.scaled_dot_product_attention(
|
1232
|
+
query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
|
1233
|
+
)
|
1234
|
+
hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1235
|
+
hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
|
1296
1236
|
|
1297
|
-
|
1298
|
-
|
1237
|
+
# split the attention outputs.
|
1238
|
+
hidden_states_ptb, encoder_hidden_states_ptb = (
|
1239
|
+
hidden_states_ptb[:, : residual.shape[1]],
|
1240
|
+
hidden_states_ptb[:, residual.shape[1] :],
|
1299
1241
|
)
|
1300
|
-
hidden_states = hidden_states.to(query.dtype)
|
1301
|
-
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1302
1242
|
|
1303
1243
|
# linear proj
|
1304
|
-
|
1244
|
+
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
|
1305
1245
|
# dropout
|
1306
|
-
|
1246
|
+
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
|
1247
|
+
if not attn.context_pre_only:
|
1248
|
+
encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
|
1307
1249
|
|
1308
|
-
|
1309
|
-
|
1250
|
+
if input_ndim == 4:
|
1251
|
+
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1252
|
+
if context_input_ndim == 4:
|
1253
|
+
encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
|
1254
|
+
batch_size, channel, height, width
|
1255
|
+
)
|
1310
1256
|
|
1311
|
-
|
1257
|
+
################ concat ###############
|
1258
|
+
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
|
1259
|
+
encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
|
1312
1260
|
|
1261
|
+
return hidden_states, encoder_hidden_states
|
1313
1262
|
|
1314
|
-
class XFormersAttnProcessor:
|
1315
|
-
r"""
|
1316
|
-
Processor for implementing memory efficient attention using xFormers.
|
1317
1263
|
|
1318
|
-
|
1319
|
-
|
1320
|
-
The base
|
1321
|
-
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
1322
|
-
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
1323
|
-
operator.
|
1324
|
-
"""
|
1264
|
+
class PAGCFGJointAttnProcessor2_0:
|
1265
|
+
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
1325
1266
|
|
1326
|
-
def __init__(self
|
1327
|
-
|
1267
|
+
def __init__(self):
|
1268
|
+
if not hasattr(F, "scaled_dot_product_attention"):
|
1269
|
+
raise ImportError(
|
1270
|
+
"PAGCFGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
1271
|
+
)
|
1328
1272
|
|
1329
1273
|
def __call__(
|
1330
1274
|
self,
|
1331
1275
|
attn: Attention,
|
1332
|
-
hidden_states: torch.
|
1333
|
-
encoder_hidden_states:
|
1334
|
-
attention_mask: Optional[torch.
|
1335
|
-
temb: Optional[torch.Tensor] = None,
|
1276
|
+
hidden_states: torch.FloatTensor,
|
1277
|
+
encoder_hidden_states: torch.FloatTensor = None,
|
1278
|
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1336
1279
|
*args,
|
1337
1280
|
**kwargs,
|
1338
|
-
) -> torch.
|
1339
|
-
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
1340
|
-
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
1341
|
-
deprecate("scale", "1.0.0", deprecation_message)
|
1342
|
-
|
1281
|
+
) -> torch.FloatTensor:
|
1343
1282
|
residual = hidden_states
|
1344
1283
|
|
1345
|
-
if attn.spatial_norm is not None:
|
1346
|
-
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1347
|
-
|
1348
1284
|
input_ndim = hidden_states.ndim
|
1349
|
-
|
1350
1285
|
if input_ndim == 4:
|
1351
1286
|
batch_size, channel, height, width = hidden_states.shape
|
1352
1287
|
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1288
|
+
context_input_ndim = encoder_hidden_states.ndim
|
1289
|
+
if context_input_ndim == 4:
|
1290
|
+
batch_size, channel, height, width = encoder_hidden_states.shape
|
1291
|
+
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1353
1292
|
|
1354
|
-
|
1355
|
-
|
1356
|
-
)
|
1293
|
+
identity_block_size = hidden_states.shape[
|
1294
|
+
1
|
1295
|
+
] # patch embeddings width * height (correspond to self-attention map width or height)
|
1357
1296
|
|
1358
|
-
|
1359
|
-
|
1360
|
-
|
1361
|
-
# [batch*heads, 1, key_tokens] ->
|
1362
|
-
# [batch*heads, query_tokens, key_tokens]
|
1363
|
-
# so that it can be added as a bias onto the attention scores that xformers computes:
|
1364
|
-
# [batch*heads, query_tokens, key_tokens]
|
1365
|
-
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
|
1366
|
-
_, query_tokens, _ = hidden_states.shape
|
1367
|
-
attention_mask = attention_mask.expand(-1, query_tokens, -1)
|
1297
|
+
# chunk
|
1298
|
+
hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
|
1299
|
+
hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
|
1368
1300
|
|
1369
|
-
|
1370
|
-
|
1301
|
+
(
|
1302
|
+
encoder_hidden_states_uncond,
|
1303
|
+
encoder_hidden_states_org,
|
1304
|
+
encoder_hidden_states_ptb,
|
1305
|
+
) = encoder_hidden_states.chunk(3)
|
1306
|
+
encoder_hidden_states_org = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_org])
|
1371
1307
|
|
1372
|
-
|
1308
|
+
################## original path ##################
|
1309
|
+
batch_size = encoder_hidden_states_org.shape[0]
|
1373
1310
|
|
1374
|
-
|
1375
|
-
|
1376
|
-
|
1377
|
-
|
1311
|
+
# `sample` projections.
|
1312
|
+
query_org = attn.to_q(hidden_states_org)
|
1313
|
+
key_org = attn.to_k(hidden_states_org)
|
1314
|
+
value_org = attn.to_v(hidden_states_org)
|
1378
1315
|
|
1379
|
-
|
1380
|
-
|
1316
|
+
# `context` projections.
|
1317
|
+
encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
|
1318
|
+
encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
|
1319
|
+
encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
|
1381
1320
|
|
1382
|
-
|
1383
|
-
|
1384
|
-
|
1321
|
+
# attention
|
1322
|
+
query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
|
1323
|
+
key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
|
1324
|
+
value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
|
1385
1325
|
|
1386
|
-
|
1387
|
-
|
1326
|
+
inner_dim = key_org.shape[-1]
|
1327
|
+
head_dim = inner_dim // attn.heads
|
1328
|
+
query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1329
|
+
key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1330
|
+
value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1331
|
+
|
1332
|
+
hidden_states_org = F.scaled_dot_product_attention(
|
1333
|
+
query_org, key_org, value_org, dropout_p=0.0, is_causal=False
|
1334
|
+
)
|
1335
|
+
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1336
|
+
hidden_states_org = hidden_states_org.to(query_org.dtype)
|
1337
|
+
|
1338
|
+
# Split the attention outputs.
|
1339
|
+
hidden_states_org, encoder_hidden_states_org = (
|
1340
|
+
hidden_states_org[:, : residual.shape[1]],
|
1341
|
+
hidden_states_org[:, residual.shape[1] :],
|
1388
1342
|
)
|
1389
|
-
hidden_states = hidden_states.to(query.dtype)
|
1390
|
-
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1391
1343
|
|
1392
1344
|
# linear proj
|
1393
|
-
|
1345
|
+
hidden_states_org = attn.to_out[0](hidden_states_org)
|
1394
1346
|
# dropout
|
1395
|
-
|
1347
|
+
hidden_states_org = attn.to_out[1](hidden_states_org)
|
1348
|
+
if not attn.context_pre_only:
|
1349
|
+
encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
|
1396
1350
|
|
1397
1351
|
if input_ndim == 4:
|
1398
|
-
|
1352
|
+
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1353
|
+
if context_input_ndim == 4:
|
1354
|
+
encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
|
1355
|
+
batch_size, channel, height, width
|
1356
|
+
)
|
1399
1357
|
|
1400
|
-
|
1401
|
-
hidden_states = hidden_states + residual
|
1358
|
+
################## perturbed path ##################
|
1402
1359
|
|
1403
|
-
|
1360
|
+
batch_size = encoder_hidden_states_ptb.shape[0]
|
1404
1361
|
|
1405
|
-
|
1362
|
+
# `sample` projections.
|
1363
|
+
query_ptb = attn.to_q(hidden_states_ptb)
|
1364
|
+
key_ptb = attn.to_k(hidden_states_ptb)
|
1365
|
+
value_ptb = attn.to_v(hidden_states_ptb)
|
1406
1366
|
|
1367
|
+
# `context` projections.
|
1368
|
+
encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
|
1369
|
+
encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
|
1370
|
+
encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
|
1407
1371
|
|
1408
|
-
|
1372
|
+
# attention
|
1373
|
+
query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
|
1374
|
+
key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
|
1375
|
+
value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
|
1409
1376
|
|
1410
|
-
|
1411
|
-
|
1412
|
-
|
1413
|
-
|
1377
|
+
inner_dim = key_ptb.shape[-1]
|
1378
|
+
head_dim = inner_dim // attn.heads
|
1379
|
+
query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1380
|
+
key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1381
|
+
value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1414
1382
|
|
1415
|
-
|
1383
|
+
# create a full mask with all entries set to 0
|
1384
|
+
seq_len = query_ptb.size(2)
|
1385
|
+
full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
|
1416
1386
|
|
1417
|
-
|
1418
|
-
|
1419
|
-
raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")
|
1387
|
+
# set the attention value between image patches to -inf
|
1388
|
+
full_mask[:identity_block_size, :identity_block_size] = float("-inf")
|
1420
1389
|
|
1421
|
-
|
1422
|
-
|
1423
|
-
attn: Attention,
|
1424
|
-
hidden_states: torch.Tensor,
|
1425
|
-
encoder_hidden_states: Optional[torch.Tensor] = None,
|
1426
|
-
attention_mask: Optional[torch.Tensor] = None,
|
1427
|
-
temb: Optional[torch.Tensor] = None,
|
1428
|
-
*args,
|
1429
|
-
**kwargs,
|
1430
|
-
) -> torch.Tensor:
|
1431
|
-
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
1432
|
-
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
1433
|
-
deprecate("scale", "1.0.0", deprecation_message)
|
1390
|
+
# set the diagonal of the attention value between image patches to 0
|
1391
|
+
full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
|
1434
1392
|
|
1435
|
-
|
1436
|
-
|
1437
|
-
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1393
|
+
# expand the mask to match the attention weights shape
|
1394
|
+
full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions
|
1438
1395
|
|
1439
|
-
|
1396
|
+
hidden_states_ptb = F.scaled_dot_product_attention(
|
1397
|
+
query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
|
1398
|
+
)
|
1399
|
+
hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1400
|
+
hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
|
1401
|
+
|
1402
|
+
# split the attention outputs.
|
1403
|
+
hidden_states_ptb, encoder_hidden_states_ptb = (
|
1404
|
+
hidden_states_ptb[:, : residual.shape[1]],
|
1405
|
+
hidden_states_ptb[:, residual.shape[1] :],
|
1406
|
+
)
|
1407
|
+
|
1408
|
+
# linear proj
|
1409
|
+
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
|
1410
|
+
# dropout
|
1411
|
+
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
|
1412
|
+
if not attn.context_pre_only:
|
1413
|
+
encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
|
1414
|
+
|
1415
|
+
if input_ndim == 4:
|
1416
|
+
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1417
|
+
if context_input_ndim == 4:
|
1418
|
+
encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
|
1419
|
+
batch_size, channel, height, width
|
1420
|
+
)
|
1421
|
+
|
1422
|
+
################ concat ###############
|
1423
|
+
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
|
1424
|
+
encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
|
1425
|
+
|
1426
|
+
return hidden_states, encoder_hidden_states
|
1427
|
+
|
1428
|
+
|
1429
|
+
class FusedJointAttnProcessor2_0:
|
1430
|
+
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
1431
|
+
|
1432
|
+
def __init__(self):
|
1433
|
+
if not hasattr(F, "scaled_dot_product_attention"):
|
1434
|
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
1435
|
+
|
1436
|
+
def __call__(
|
1437
|
+
self,
|
1438
|
+
attn: Attention,
|
1439
|
+
hidden_states: torch.FloatTensor,
|
1440
|
+
encoder_hidden_states: torch.FloatTensor = None,
|
1441
|
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1442
|
+
*args,
|
1443
|
+
**kwargs,
|
1444
|
+
) -> torch.FloatTensor:
|
1445
|
+
residual = hidden_states
|
1446
|
+
|
1447
|
+
input_ndim = hidden_states.ndim
|
1448
|
+
if input_ndim == 4:
|
1449
|
+
batch_size, channel, height, width = hidden_states.shape
|
1450
|
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1451
|
+
context_input_ndim = encoder_hidden_states.ndim
|
1452
|
+
if context_input_ndim == 4:
|
1453
|
+
batch_size, channel, height, width = encoder_hidden_states.shape
|
1454
|
+
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1455
|
+
|
1456
|
+
batch_size = encoder_hidden_states.shape[0]
|
1457
|
+
|
1458
|
+
# `sample` projections.
|
1459
|
+
qkv = attn.to_qkv(hidden_states)
|
1460
|
+
split_size = qkv.shape[-1] // 3
|
1461
|
+
query, key, value = torch.split(qkv, split_size, dim=-1)
|
1462
|
+
|
1463
|
+
# `context` projections.
|
1464
|
+
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
|
1465
|
+
split_size = encoder_qkv.shape[-1] // 3
|
1466
|
+
(
|
1467
|
+
encoder_hidden_states_query_proj,
|
1468
|
+
encoder_hidden_states_key_proj,
|
1469
|
+
encoder_hidden_states_value_proj,
|
1470
|
+
) = torch.split(encoder_qkv, split_size, dim=-1)
|
1471
|
+
|
1472
|
+
# attention
|
1473
|
+
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
|
1474
|
+
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
|
1475
|
+
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
|
1476
|
+
|
1477
|
+
inner_dim = key.shape[-1]
|
1478
|
+
head_dim = inner_dim // attn.heads
|
1479
|
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1480
|
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1481
|
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1482
|
+
|
1483
|
+
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
1484
|
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1485
|
+
hidden_states = hidden_states.to(query.dtype)
|
1486
|
+
|
1487
|
+
# Split the attention outputs.
|
1488
|
+
hidden_states, encoder_hidden_states = (
|
1489
|
+
hidden_states[:, : residual.shape[1]],
|
1490
|
+
hidden_states[:, residual.shape[1] :],
|
1491
|
+
)
|
1492
|
+
|
1493
|
+
# linear proj
|
1494
|
+
hidden_states = attn.to_out[0](hidden_states)
|
1495
|
+
# dropout
|
1496
|
+
hidden_states = attn.to_out[1](hidden_states)
|
1497
|
+
if not attn.context_pre_only:
|
1498
|
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
1499
|
+
|
1500
|
+
if input_ndim == 4:
|
1501
|
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1502
|
+
if context_input_ndim == 4:
|
1503
|
+
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1504
|
+
|
1505
|
+
return hidden_states, encoder_hidden_states
|
1506
|
+
|
1507
|
+
|
1508
|
+
+class AuraFlowAttnProcessor2_0:
+    """Attention processor used typically in processing Aura Flow."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
+            raise ImportError(
+                "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        batch_size = hidden_states.shape[0]
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+        # Reshape.
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim)
+        key = key.view(batch_size, -1, attn.heads, head_dim)
+        value = value.view(batch_size, -1, attn.heads, head_dim)
+
+        # Apply QK norm.
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Concatenate the projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            )
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            )
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        # Attention.
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # Split the attention outputs.
+        if encoder_hidden_states is not None:
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+            )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        if encoder_hidden_states is not None:
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        if encoder_hidden_states is not None:
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
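A minimal, hedged sketch of attaching the new processor by hand. It assumes the target model's attention layers are standard diffusers `Attention` modules that already carry the added-KV projections AuraFlow expects; `set_processor` is the existing per-module hook, and the helper name below is illustrative.

import torch
from diffusers.models.attention_processor import Attention, AuraFlowAttnProcessor2_0

def use_aura_flow_attention(model: torch.nn.Module) -> None:
    # Attach the AuraFlow SDPA processor to every Attention block in the model.
    for module in model.modules():
        if isinstance(module, Attention):
            module.set_processor(AuraFlowAttnProcessor2_0())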
+class FusedAuraFlowAttnProcessor2_0:
+    """Attention processor used typically in processing Aura Flow with fused projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
+            raise ImportError(
+                "FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        batch_size = hidden_states.shape[0]
+
+        # `sample` projections.
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+            split_size = encoder_qkv.shape[-1] // 3
+            (
+                encoder_hidden_states_query_proj,
+                encoder_hidden_states_key_proj,
+                encoder_hidden_states_value_proj,
+            ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+        # Reshape.
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim)
+        key = key.view(batch_size, -1, attn.heads, head_dim)
+        value = value.view(batch_size, -1, attn.heads, head_dim)
+
+        # Apply QK norm.
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Concatenate the projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            )
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            )
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        # Attention.
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # Split the attention outputs.
+        if encoder_hidden_states is not None:
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+            )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        if encoder_hidden_states is not None:
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        if encoder_hidden_states is not None:
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
+# YiYi to-do: refactor rope related functions/classes
+def apply_rope(xq, xk, freqs_cis):
+    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+
+
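`apply_rope` rotates channel pairs, so `freqs_cis` must broadcast against a `(..., head_dim // 2, 1, 2)` view of the query and key. A small shape sketch under assumed, illustrative sizes (the import path simply points at the module this diff modifies):

import torch
from diffusers.models.attention_processor import apply_rope

batch, heads, seq_len, head_dim = 2, 8, 16, 64
xq = torch.randn(batch, heads, seq_len, head_dim)
xk = torch.randn(batch, heads, seq_len, head_dim)
# Rotary table broadcastable against (batch, heads, seq_len, head_dim // 2, 2, 2).
freqs_cis = torch.randn(1, 1, seq_len, head_dim // 2, 2, 2)

xq_rot, xk_rot = apply_rope(xq, xk, freqs_cis)
assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape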
class FluxSingleAttnProcessor2_0:
|
1708
|
+
r"""
|
1709
|
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
1710
|
+
"""
|
1711
|
+
|
1712
|
+
def __init__(self):
|
1713
|
+
if not hasattr(F, "scaled_dot_product_attention"):
|
1714
|
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
1715
|
+
|
1716
|
+
def __call__(
|
1717
|
+
self,
|
1718
|
+
attn: Attention,
|
1719
|
+
hidden_states: torch.Tensor,
|
1720
|
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
1721
|
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1722
|
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
1723
|
+
) -> torch.Tensor:
|
1724
|
+
input_ndim = hidden_states.ndim
|
1725
|
+
|
1726
|
+
if input_ndim == 4:
|
1727
|
+
batch_size, channel, height, width = hidden_states.shape
|
1728
|
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1729
|
+
|
1730
|
+
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1731
|
+
|
1732
|
+
query = attn.to_q(hidden_states)
|
1733
|
+
if encoder_hidden_states is None:
|
1734
|
+
encoder_hidden_states = hidden_states
|
1735
|
+
|
1736
|
+
key = attn.to_k(encoder_hidden_states)
|
1737
|
+
value = attn.to_v(encoder_hidden_states)
|
1738
|
+
|
1739
|
+
inner_dim = key.shape[-1]
|
1740
|
+
head_dim = inner_dim // attn.heads
|
1741
|
+
|
1742
|
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1743
|
+
|
1744
|
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1745
|
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1746
|
+
|
1747
|
+
if attn.norm_q is not None:
|
1748
|
+
query = attn.norm_q(query)
|
1749
|
+
if attn.norm_k is not None:
|
1750
|
+
key = attn.norm_k(key)
|
1751
|
+
|
1752
|
+
# Apply RoPE if needed
|
1753
|
+
if image_rotary_emb is not None:
|
1754
|
+
# YiYi to-do: update using apply_rotary_emb
|
1755
|
+
# from ..embeddings import apply_rotary_emb
|
1756
|
+
# query = apply_rotary_emb(query, image_rotary_emb)
|
1757
|
+
# key = apply_rotary_emb(key, image_rotary_emb)
|
1758
|
+
query, key = apply_rope(query, key, image_rotary_emb)
|
1759
|
+
|
1760
|
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
1761
|
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
1762
|
+
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
1763
|
+
|
1764
|
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1765
|
+
hidden_states = hidden_states.to(query.dtype)
|
1766
|
+
|
1767
|
+
if input_ndim == 4:
|
1768
|
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1769
|
+
|
1770
|
+
return hidden_states
|
1771
|
+
|
1772
|
+
|
1773
|
+
class FluxAttnProcessor2_0:
|
1774
|
+
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
1775
|
+
|
1776
|
+
def __init__(self):
|
1777
|
+
if not hasattr(F, "scaled_dot_product_attention"):
|
1778
|
+
raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
1779
|
+
|
1780
|
+
def __call__(
|
1781
|
+
self,
|
1782
|
+
attn: Attention,
|
1783
|
+
hidden_states: torch.FloatTensor,
|
1784
|
+
encoder_hidden_states: torch.FloatTensor = None,
|
1785
|
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1786
|
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
1787
|
+
) -> torch.FloatTensor:
|
1788
|
+
input_ndim = hidden_states.ndim
|
1789
|
+
if input_ndim == 4:
|
1790
|
+
batch_size, channel, height, width = hidden_states.shape
|
1791
|
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1792
|
+
context_input_ndim = encoder_hidden_states.ndim
|
1793
|
+
if context_input_ndim == 4:
|
1794
|
+
batch_size, channel, height, width = encoder_hidden_states.shape
|
1795
|
+
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1796
|
+
|
1797
|
+
batch_size = encoder_hidden_states.shape[0]
|
1798
|
+
|
1799
|
+
# `sample` projections.
|
1800
|
+
query = attn.to_q(hidden_states)
|
1801
|
+
key = attn.to_k(hidden_states)
|
1802
|
+
value = attn.to_v(hidden_states)
|
1803
|
+
|
1804
|
+
inner_dim = key.shape[-1]
|
1805
|
+
head_dim = inner_dim // attn.heads
|
1806
|
+
|
1807
|
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1808
|
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1809
|
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1810
|
+
|
1811
|
+
if attn.norm_q is not None:
|
1812
|
+
query = attn.norm_q(query)
|
1813
|
+
if attn.norm_k is not None:
|
1814
|
+
key = attn.norm_k(key)
|
1815
|
+
|
1816
|
+
# `context` projections.
|
1817
|
+
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
1818
|
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
1819
|
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
1820
|
+
|
1821
|
+
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
1822
|
+
batch_size, -1, attn.heads, head_dim
|
1823
|
+
).transpose(1, 2)
|
1824
|
+
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
1825
|
+
batch_size, -1, attn.heads, head_dim
|
1826
|
+
).transpose(1, 2)
|
1827
|
+
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
1828
|
+
batch_size, -1, attn.heads, head_dim
|
1829
|
+
).transpose(1, 2)
|
1830
|
+
|
1831
|
+
if attn.norm_added_q is not None:
|
1832
|
+
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
1833
|
+
if attn.norm_added_k is not None:
|
1834
|
+
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
1835
|
+
|
1836
|
+
# attention
|
1837
|
+
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
1838
|
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
1839
|
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
1840
|
+
|
1841
|
+
if image_rotary_emb is not None:
|
1842
|
+
# YiYi to-do: update using apply_rotary_emb
|
1843
|
+
# from ..embeddings import apply_rotary_emb
|
1844
|
+
# query = apply_rotary_emb(query, image_rotary_emb)
|
1845
|
+
# key = apply_rotary_emb(key, image_rotary_emb)
|
1846
|
+
query, key = apply_rope(query, key, image_rotary_emb)
|
1847
|
+
|
1848
|
+
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
1849
|
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1850
|
+
hidden_states = hidden_states.to(query.dtype)
|
1851
|
+
|
1852
|
+
encoder_hidden_states, hidden_states = (
|
1853
|
+
hidden_states[:, : encoder_hidden_states.shape[1]],
|
1854
|
+
hidden_states[:, encoder_hidden_states.shape[1] :],
|
1855
|
+
)
|
1856
|
+
|
1857
|
+
# linear proj
|
1858
|
+
hidden_states = attn.to_out[0](hidden_states)
|
1859
|
+
# dropout
|
1860
|
+
hidden_states = attn.to_out[1](hidden_states)
|
1861
|
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
1862
|
+
|
1863
|
+
if input_ndim == 4:
|
1864
|
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1865
|
+
if context_input_ndim == 4:
|
1866
|
+
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1867
|
+
|
1868
|
+
return hidden_states, encoder_hidden_states
|
1869
|
+
|
1870
|
+
|
1871
|
+
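For orientation: the Flux processor above prepends the text (`context`) projections to the image projections along the sequence axis of the head-split tensors, which is why the result is later split with `encoder_hidden_states.shape[1]` tokens first. A shape-only sketch with made-up sizes:

import torch

batch, heads, txt_len, img_len, head_dim = 1, 24, 512, 4096, 128
query_txt = torch.randn(batch, heads, txt_len, head_dim)
query_img = torch.randn(batch, heads, img_len, head_dim)
# Text tokens first, matching the split performed after attention.
query = torch.cat([query_txt, query_img], dim=2)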
class CogVideoXAttnProcessor2_0:
|
1872
|
+
r"""
|
1873
|
+
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
1874
|
+
query and key vectors, but does not include spatial normalization.
|
1875
|
+
"""
|
1876
|
+
|
1877
|
+
def __init__(self):
|
1878
|
+
if not hasattr(F, "scaled_dot_product_attention"):
|
1879
|
+
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
1880
|
+
|
1881
|
+
def __call__(
|
1882
|
+
self,
|
1883
|
+
attn: Attention,
|
1884
|
+
hidden_states: torch.Tensor,
|
1885
|
+
encoder_hidden_states: torch.Tensor,
|
1886
|
+
attention_mask: Optional[torch.Tensor] = None,
|
1887
|
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
1888
|
+
) -> torch.Tensor:
|
1889
|
+
text_seq_length = encoder_hidden_states.size(1)
|
1890
|
+
|
1891
|
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
1892
|
+
|
1893
|
+
batch_size, sequence_length, _ = (
|
1894
|
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1895
|
+
)
|
1896
|
+
|
1897
|
+
if attention_mask is not None:
|
1898
|
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1899
|
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
1900
|
+
|
1901
|
+
query = attn.to_q(hidden_states)
|
1902
|
+
key = attn.to_k(hidden_states)
|
1903
|
+
value = attn.to_v(hidden_states)
|
1904
|
+
|
1905
|
+
inner_dim = key.shape[-1]
|
1906
|
+
head_dim = inner_dim // attn.heads
|
1907
|
+
|
1908
|
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1909
|
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1910
|
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1911
|
+
|
1912
|
+
if attn.norm_q is not None:
|
1913
|
+
query = attn.norm_q(query)
|
1914
|
+
if attn.norm_k is not None:
|
1915
|
+
key = attn.norm_k(key)
|
1916
|
+
|
1917
|
+
# Apply RoPE if needed
|
1918
|
+
if image_rotary_emb is not None:
|
1919
|
+
from .embeddings import apply_rotary_emb
|
1920
|
+
|
1921
|
+
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
1922
|
+
if not attn.is_cross_attention:
|
1923
|
+
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
1924
|
+
|
1925
|
+
hidden_states = F.scaled_dot_product_attention(
|
1926
|
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
1927
|
+
)
|
1928
|
+
|
1929
|
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1930
|
+
|
1931
|
+
# linear proj
|
1932
|
+
hidden_states = attn.to_out[0](hidden_states)
|
1933
|
+
# dropout
|
1934
|
+
hidden_states = attn.to_out[1](hidden_states)
|
1935
|
+
|
1936
|
+
encoder_hidden_states, hidden_states = hidden_states.split(
|
1937
|
+
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
1938
|
+
)
|
1939
|
+
return hidden_states, encoder_hidden_states
|
1940
|
+
|
1941
|
+
|
1942
|
+
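The CogVideoX processor runs one joint attention over text and video tokens: the text sequence is prepended, attention runs over the packed sequence, and the output is split back by the original text length. An illustration with made-up sizes:

import torch

text_seq_length, video_seq_length, dim = 226, 1350, 3072
encoder_hidden_states = torch.randn(1, text_seq_length, dim)
hidden_states = torch.randn(1, video_seq_length, dim)

packed = torch.cat([encoder_hidden_states, hidden_states], dim=1)
# ... joint attention over `packed` ...
text_out, video_out = packed.split([text_seq_length, packed.size(1) - text_seq_length], dim=1)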
class FusedCogVideoXAttnProcessor2_0:
|
1943
|
+
r"""
|
1944
|
+
Processor for implementing scaled dot-product attention for the CogVideoX model with fused QKV projections. It applies a rotary embedding on
|
1945
|
+
query and key vectors, but does not include spatial normalization.
|
1946
|
+
"""
|
1947
|
+
|
1948
|
+
def __init__(self):
|
1949
|
+
if not hasattr(F, "scaled_dot_product_attention"):
|
1950
|
+
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
1951
|
+
|
1952
|
+
def __call__(
|
1953
|
+
self,
|
1954
|
+
attn: Attention,
|
1955
|
+
hidden_states: torch.Tensor,
|
1956
|
+
encoder_hidden_states: torch.Tensor,
|
1957
|
+
attention_mask: Optional[torch.Tensor] = None,
|
1958
|
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
1959
|
+
) -> torch.Tensor:
|
1960
|
+
text_seq_length = encoder_hidden_states.size(1)
|
1961
|
+
|
1962
|
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
1963
|
+
|
1964
|
+
batch_size, sequence_length, _ = (
|
1965
|
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1966
|
+
)
|
1967
|
+
|
1968
|
+
if attention_mask is not None:
|
1969
|
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1970
|
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
1971
|
+
|
1972
|
+
qkv = attn.to_qkv(hidden_states)
|
1973
|
+
split_size = qkv.shape[-1] // 3
|
1974
|
+
query, key, value = torch.split(qkv, split_size, dim=-1)
|
1975
|
+
|
1976
|
+
inner_dim = key.shape[-1]
|
1977
|
+
head_dim = inner_dim // attn.heads
|
1978
|
+
|
1979
|
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1980
|
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1981
|
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1982
|
+
|
1983
|
+
if attn.norm_q is not None:
|
1984
|
+
query = attn.norm_q(query)
|
1985
|
+
if attn.norm_k is not None:
|
1986
|
+
key = attn.norm_k(key)
|
1987
|
+
|
1988
|
+
# Apply RoPE if needed
|
1989
|
+
if image_rotary_emb is not None:
|
1990
|
+
from .embeddings import apply_rotary_emb
|
1991
|
+
|
1992
|
+
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
1993
|
+
if not attn.is_cross_attention:
|
1994
|
+
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
1995
|
+
|
1996
|
+
hidden_states = F.scaled_dot_product_attention(
|
1997
|
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
1998
|
+
)
|
1999
|
+
|
2000
|
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
2001
|
+
|
2002
|
+
# linear proj
|
2003
|
+
hidden_states = attn.to_out[0](hidden_states)
|
2004
|
+
# dropout
|
2005
|
+
hidden_states = attn.to_out[1](hidden_states)
|
2006
|
+
|
2007
|
+
encoder_hidden_states, hidden_states = hidden_states.split(
|
2008
|
+
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
2009
|
+
)
|
2010
|
+
return hidden_states, encoder_hidden_states
|
2011
|
+
|
2012
|
+
|
2013
|
+
class XFormersAttnAddedKVProcessor:
|
2014
|
+
r"""
|
2015
|
+
Processor for implementing memory efficient attention using xFormers.
|
2016
|
+
|
2017
|
+
Args:
|
2018
|
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
2019
|
+
The base
|
2020
|
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
2021
|
+
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
2022
|
+
operator.
|
2023
|
+
"""
|
2024
|
+
|
2025
|
+
def __init__(self, attention_op: Optional[Callable] = None):
|
2026
|
+
self.attention_op = attention_op
|
2027
|
+
|
2028
|
+
def __call__(
|
2029
|
+
self,
|
2030
|
+
attn: Attention,
|
2031
|
+
hidden_states: torch.Tensor,
|
2032
|
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
2033
|
+
attention_mask: Optional[torch.Tensor] = None,
|
2034
|
+
) -> torch.Tensor:
|
2035
|
+
residual = hidden_states
|
2036
|
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
2037
|
+
batch_size, sequence_length, _ = hidden_states.shape
|
2038
|
+
|
2039
|
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
2040
|
+
|
2041
|
+
if encoder_hidden_states is None:
|
2042
|
+
encoder_hidden_states = hidden_states
|
2043
|
+
elif attn.norm_cross:
|
2044
|
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
2045
|
+
|
2046
|
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
2047
|
+
|
2048
|
+
query = attn.to_q(hidden_states)
|
2049
|
+
query = attn.head_to_batch_dim(query)
|
2050
|
+
|
2051
|
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
2052
|
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
2053
|
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
2054
|
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
2055
|
+
|
2056
|
+
if not attn.only_cross_attention:
|
2057
|
+
key = attn.to_k(hidden_states)
|
2058
|
+
value = attn.to_v(hidden_states)
|
2059
|
+
key = attn.head_to_batch_dim(key)
|
2060
|
+
value = attn.head_to_batch_dim(value)
|
2061
|
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
2062
|
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
2063
|
+
else:
|
2064
|
+
key = encoder_hidden_states_key_proj
|
2065
|
+
value = encoder_hidden_states_value_proj
|
2066
|
+
|
2067
|
+
hidden_states = xformers.ops.memory_efficient_attention(
|
2068
|
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
2069
|
+
)
|
2070
|
+
hidden_states = hidden_states.to(query.dtype)
|
2071
|
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
2072
|
+
|
2073
|
+
# linear proj
|
2074
|
+
hidden_states = attn.to_out[0](hidden_states)
|
2075
|
+
# dropout
|
2076
|
+
hidden_states = attn.to_out[1](hidden_states)
|
2077
|
+
|
2078
|
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
2079
|
+
hidden_states = hidden_states + residual
|
2080
|
+
|
2081
|
+
return hidden_states
|
2082
|
+
|
2083
|
+
|
2084
|
+
class XFormersAttnProcessor:
|
2085
|
+
r"""
|
2086
|
+
Processor for implementing memory efficient attention using xFormers.
|
2087
|
+
|
2088
|
+
Args:
|
2089
|
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
2090
|
+
The base
|
2091
|
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
2092
|
+
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
2093
|
+
operator.
|
2094
|
+
"""
|
2095
|
+
|
2096
|
+
def __init__(self, attention_op: Optional[Callable] = None):
|
2097
|
+
self.attention_op = attention_op
|
2098
|
+
|
2099
|
+
def __call__(
|
2100
|
+
self,
|
2101
|
+
attn: Attention,
|
2102
|
+
hidden_states: torch.Tensor,
|
2103
|
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
2104
|
+
attention_mask: Optional[torch.Tensor] = None,
|
2105
|
+
temb: Optional[torch.Tensor] = None,
|
2106
|
+
*args,
|
2107
|
+
**kwargs,
|
2108
|
+
) -> torch.Tensor:
|
2109
|
+
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
2110
|
+
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
2111
|
+
deprecate("scale", "1.0.0", deprecation_message)
|
2112
|
+
|
2113
|
+
residual = hidden_states
|
2114
|
+
|
2115
|
+
if attn.spatial_norm is not None:
|
2116
|
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
2117
|
+
|
2118
|
+
input_ndim = hidden_states.ndim
|
2119
|
+
|
2120
|
+
if input_ndim == 4:
|
2121
|
+
batch_size, channel, height, width = hidden_states.shape
|
2122
|
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
2123
|
+
|
2124
|
+
batch_size, key_tokens, _ = (
|
2125
|
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
2126
|
+
)
|
2127
|
+
|
2128
|
+
attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
|
2129
|
+
if attention_mask is not None:
|
2130
|
+
# expand our mask's singleton query_tokens dimension:
|
2131
|
+
# [batch*heads, 1, key_tokens] ->
|
2132
|
+
# [batch*heads, query_tokens, key_tokens]
|
2133
|
+
# so that it can be added as a bias onto the attention scores that xformers computes:
|
2134
|
+
# [batch*heads, query_tokens, key_tokens]
|
2135
|
+
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
|
2136
|
+
_, query_tokens, _ = hidden_states.shape
|
2137
|
+
attention_mask = attention_mask.expand(-1, query_tokens, -1)
|
2138
|
+
|
2139
|
+
if attn.group_norm is not None:
|
2140
|
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
2141
|
+
|
2142
|
+
query = attn.to_q(hidden_states)
|
2143
|
+
|
2144
|
+
if encoder_hidden_states is None:
|
2145
|
+
encoder_hidden_states = hidden_states
|
2146
|
+
elif attn.norm_cross:
|
2147
|
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
2148
|
+
|
2149
|
+
key = attn.to_k(encoder_hidden_states)
|
2150
|
+
value = attn.to_v(encoder_hidden_states)
|
2151
|
+
|
2152
|
+
query = attn.head_to_batch_dim(query).contiguous()
|
2153
|
+
key = attn.head_to_batch_dim(key).contiguous()
|
2154
|
+
value = attn.head_to_batch_dim(value).contiguous()
|
2155
|
+
|
2156
|
+
hidden_states = xformers.ops.memory_efficient_attention(
|
2157
|
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
2158
|
+
)
|
2159
|
+
hidden_states = hidden_states.to(query.dtype)
|
2160
|
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
2161
|
+
|
2162
|
+
# linear proj
|
2163
|
+
hidden_states = attn.to_out[0](hidden_states)
|
2164
|
+
# dropout
|
2165
|
+
hidden_states = attn.to_out[1](hidden_states)
|
2166
|
+
|
2167
|
+
if input_ndim == 4:
|
2168
|
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
2169
|
+
|
2170
|
+
if attn.residual_connection:
|
2171
|
+
hidden_states = hidden_states + residual
|
2172
|
+
|
2173
|
+
hidden_states = hidden_states / attn.rescale_output_factor
|
2174
|
+
|
2175
|
+
return hidden_states
|
2176
|
+
|
2177
|
+
|
2178
|
+
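These xFormers processors are usually installed via the pipeline-level toggle rather than by hand. A hedged usage sketch (the checkpoint id is only illustrative, and the `xformers` package must be installed):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.enable_xformers_memory_efficient_attention()  # swaps in the xFormers attention processors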
class AttnProcessorNPU:
|
2179
|
+
r"""
|
2180
|
+
Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
|
2181
|
+
fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
|
2182
|
+
not significant.
|
2183
|
+
|
2184
|
+
"""
|
2185
|
+
|
2186
|
+
def __init__(self):
|
2187
|
+
if not is_torch_npu_available():
|
2188
|
+
raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")
|
2189
|
+
|
2190
|
+
def __call__(
|
2191
|
+
self,
|
2192
|
+
attn: Attention,
|
2193
|
+
hidden_states: torch.Tensor,
|
2194
|
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
2195
|
+
attention_mask: Optional[torch.Tensor] = None,
|
2196
|
+
temb: Optional[torch.Tensor] = None,
|
2197
|
+
*args,
|
2198
|
+
**kwargs,
|
2199
|
+
) -> torch.Tensor:
|
2200
|
+
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
2201
|
+
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
2202
|
+
deprecate("scale", "1.0.0", deprecation_message)
|
2203
|
+
|
2204
|
+
residual = hidden_states
|
2205
|
+
if attn.spatial_norm is not None:
|
2206
|
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
2207
|
+
|
2208
|
+
input_ndim = hidden_states.ndim
|
2209
|
+
|
2210
|
+
if input_ndim == 4:
|
2211
|
+
batch_size, channel, height, width = hidden_states.shape
|
2212
|
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
2213
|
+
|
2214
|
+
batch_size, sequence_length, _ = (
|
2215
|
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
2216
|
+
)
|
2217
|
+
|
2218
|
+
if attention_mask is not None:
|
2219
|
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
2220
|
+
# scaled_dot_product_attention expects attention_mask shape to be
|
2221
|
+
# (batch, heads, source_length, target_length)
|
2222
|
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
2223
|
+
|
2224
|
+
if attn.group_norm is not None:
|
2225
|
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
2226
|
+
|
2227
|
+
query = attn.to_q(hidden_states)
|
2228
|
+
|
2229
|
+
if encoder_hidden_states is None:
|
2230
|
+
encoder_hidden_states = hidden_states
|
2231
|
+
elif attn.norm_cross:
|
2232
|
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
2233
|
+
|
2234
|
+
key = attn.to_k(encoder_hidden_states)
|
2235
|
+
value = attn.to_v(encoder_hidden_states)
|
2236
|
+
|
2237
|
+
inner_dim = key.shape[-1]
|
2238
|
+
head_dim = inner_dim // attn.heads
|
2239
|
+
|
2240
|
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2241
|
+
|
2242
|
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2243
|
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2244
|
+
|
2245
|
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
2246
|
+
if query.dtype in (torch.float16, torch.bfloat16):
|
2247
|
+
hidden_states = torch_npu.npu_fusion_attention(
|
2248
|
+
query,
|
2249
|
+
key,
|
2250
|
+
value,
|
2251
|
+
attn.heads,
|
2252
|
+
input_layout="BNSD",
|
2253
|
+
pse=None,
|
2254
|
+
atten_mask=attention_mask,
|
2255
|
+
scale=1.0 / math.sqrt(query.shape[-1]),
|
2256
|
+
pre_tockens=65536,
|
2257
|
+
next_tockens=65536,
|
2258
|
+
keep_prob=1.0,
|
2259
|
+
sync=False,
|
2260
|
+
inner_precise=0,
|
2261
|
+
)[0]
|
2262
|
+
else:
|
2263
|
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
2264
|
+
hidden_states = F.scaled_dot_product_attention(
|
2265
|
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
2266
|
+
)
|
2267
|
+
|
2268
|
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
2269
|
+
hidden_states = hidden_states.to(query.dtype)
|
2270
|
+
|
2271
|
+
# linear proj
|
2272
|
+
hidden_states = attn.to_out[0](hidden_states)
|
2273
|
+
# dropout
|
2274
|
+
hidden_states = attn.to_out[1](hidden_states)
|
2275
|
+
|
2276
|
+
if input_ndim == 4:
|
2277
|
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
2278
|
+
|
2279
|
+
if attn.residual_connection:
|
2280
|
+
hidden_states = hidden_states + residual
|
2281
|
+
|
2282
|
+
hidden_states = hidden_states / attn.rescale_output_factor
|
2283
|
+
|
2284
|
+
return hidden_states
|
2285
|
+
|
2286
|
+
|
2287
|
+
class AttnProcessor2_0:
|
2288
|
+
r"""
|
2289
|
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
2290
|
+
"""
|
2291
|
+
|
2292
|
+
def __init__(self):
|
2293
|
+
if not hasattr(F, "scaled_dot_product_attention"):
|
2294
|
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
2295
|
+
|
2296
|
+
def __call__(
|
2297
|
+
self,
|
2298
|
+
attn: Attention,
|
2299
|
+
hidden_states: torch.Tensor,
|
2300
|
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
2301
|
+
attention_mask: Optional[torch.Tensor] = None,
|
2302
|
+
temb: Optional[torch.Tensor] = None,
|
2303
|
+
*args,
|
2304
|
+
**kwargs,
|
2305
|
+
) -> torch.Tensor:
|
2306
|
+
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
2307
|
+
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
2308
|
+
deprecate("scale", "1.0.0", deprecation_message)
|
2309
|
+
|
2310
|
+
residual = hidden_states
|
2311
|
+
if attn.spatial_norm is not None:
|
2312
|
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
2313
|
+
|
2314
|
+
input_ndim = hidden_states.ndim
|
2315
|
+
|
2316
|
+
if input_ndim == 4:
|
2317
|
+
batch_size, channel, height, width = hidden_states.shape
|
2318
|
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
2319
|
+
|
2320
|
+
batch_size, sequence_length, _ = (
|
2321
|
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
2322
|
+
)
|
2323
|
+
|
2324
|
+
if attention_mask is not None:
|
2325
|
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
2326
|
+
# scaled_dot_product_attention expects attention_mask shape to be
|
2327
|
+
# (batch, heads, source_length, target_length)
|
2328
|
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
2329
|
+
|
2330
|
+
if attn.group_norm is not None:
|
2331
|
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
2332
|
+
|
2333
|
+
query = attn.to_q(hidden_states)
|
2334
|
+
|
2335
|
+
if encoder_hidden_states is None:
|
2336
|
+
encoder_hidden_states = hidden_states
|
2337
|
+
elif attn.norm_cross:
|
2338
|
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
2339
|
+
|
2340
|
+
key = attn.to_k(encoder_hidden_states)
|
2341
|
+
value = attn.to_v(encoder_hidden_states)
|
2342
|
+
|
2343
|
+
inner_dim = key.shape[-1]
|
2344
|
+
head_dim = inner_dim // attn.heads
|
2345
|
+
|
2346
|
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2347
|
+
|
2348
|
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2349
|
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2350
|
+
|
2351
|
+
if attn.norm_q is not None:
|
2352
|
+
query = attn.norm_q(query)
|
2353
|
+
if attn.norm_k is not None:
|
2354
|
+
key = attn.norm_k(key)
|
2355
|
+
|
2356
|
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
2357
|
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
2358
|
+
hidden_states = F.scaled_dot_product_attention(
|
2359
|
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
2360
|
+
)
|
2361
|
+
|
2362
|
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
2363
|
+
hidden_states = hidden_states.to(query.dtype)
|
2364
|
+
|
2365
|
+
# linear proj
|
2366
|
+
hidden_states = attn.to_out[0](hidden_states)
|
2367
|
+
# dropout
|
2368
|
+
hidden_states = attn.to_out[1](hidden_states)
|
2369
|
+
|
2370
|
+
if input_ndim == 4:
|
2371
|
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
2372
|
+
|
2373
|
+
if attn.residual_connection:
|
2374
|
+
hidden_states = hidden_states + residual
|
2375
|
+
|
2376
|
+
hidden_states = hidden_states / attn.rescale_output_factor
|
2377
|
+
|
2378
|
+
return hidden_states
|
2379
|
+
|
2380
|
+
|
2381
|
+
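`AttnProcessor2_0` is the default on PyTorch 2.x, but it can also be set explicitly, for example to reset a model after trying other processors. A short sketch (checkpoint id illustrative; `set_attn_processor` accepts a single processor or a per-layer dict):

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.set_attn_processor(AttnProcessor2_0())  # use plain SDPA everywhere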
class StableAudioAttnProcessor2_0:
|
2382
|
+
r"""
|
2383
|
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
2384
|
+
used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA.
|
2385
|
+
"""
|
2386
|
+
|
2387
|
+
def __init__(self):
|
2388
|
+
if not hasattr(F, "scaled_dot_product_attention"):
|
2389
|
+
raise ImportError(
|
2390
|
+
"StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
2391
|
+
)
|
2392
|
+
|
2393
|
+
def apply_partial_rotary_emb(
|
2394
|
+
self,
|
2395
|
+
x: torch.Tensor,
|
2396
|
+
freqs_cis: Tuple[torch.Tensor],
|
2397
|
+
) -> torch.Tensor:
|
2398
|
+
from .embeddings import apply_rotary_emb
|
2399
|
+
|
2400
|
+
rot_dim = freqs_cis[0].shape[-1]
|
2401
|
+
x_to_rotate, x_unrotated = x[..., :rot_dim], x[..., rot_dim:]
|
2402
|
+
|
2403
|
+
x_rotated = apply_rotary_emb(x_to_rotate, freqs_cis, use_real=True, use_real_unbind_dim=-2)
|
2404
|
+
|
2405
|
+
out = torch.cat((x_rotated, x_unrotated), dim=-1)
|
2406
|
+
return out
|
2407
|
+
|
2408
|
+
def __call__(
|
2409
|
+
self,
|
2410
|
+
attn: Attention,
|
2411
|
+
hidden_states: torch.Tensor,
|
2412
|
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
2413
|
+
attention_mask: Optional[torch.Tensor] = None,
|
2414
|
+
rotary_emb: Optional[torch.Tensor] = None,
|
2415
|
+
) -> torch.Tensor:
|
2416
|
+
from .embeddings import apply_rotary_emb
|
2417
|
+
|
2418
|
+
residual = hidden_states
|
2419
|
+
|
2420
|
+
input_ndim = hidden_states.ndim
|
2421
|
+
|
2422
|
+
if input_ndim == 4:
|
2423
|
+
batch_size, channel, height, width = hidden_states.shape
|
2424
|
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
2425
|
+
|
2426
|
+
batch_size, sequence_length, _ = (
|
2427
|
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
2428
|
+
)
|
2429
|
+
|
2430
|
+
if attention_mask is not None:
|
2431
|
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
2432
|
+
# scaled_dot_product_attention expects attention_mask shape to be
|
2433
|
+
# (batch, heads, source_length, target_length)
|
2434
|
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
2435
|
+
|
2436
|
+
query = attn.to_q(hidden_states)
|
2437
|
+
|
2438
|
+
if encoder_hidden_states is None:
|
2439
|
+
encoder_hidden_states = hidden_states
|
2440
|
+
elif attn.norm_cross:
|
2441
|
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
2442
|
+
|
2443
|
+
key = attn.to_k(encoder_hidden_states)
|
2444
|
+
value = attn.to_v(encoder_hidden_states)
|
2445
|
+
|
2446
|
+
head_dim = query.shape[-1] // attn.heads
|
2447
|
+
kv_heads = key.shape[-1] // head_dim
|
2448
|
+
|
2449
|
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2450
|
+
|
2451
|
+
key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
|
2452
|
+
value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
|
2453
|
+
|
2454
|
+
if kv_heads != attn.heads:
|
2455
|
+
# if GQA or MQA, repeat the key/value heads to reach the number of query heads.
|
2456
|
+
heads_per_kv_head = attn.heads // kv_heads
|
2457
|
+
key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
|
2458
|
+
value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
|
2459
|
+
|
2460
|
+
if attn.norm_q is not None:
|
2461
|
+
query = attn.norm_q(query)
|
2462
|
+
if attn.norm_k is not None:
|
2463
|
+
key = attn.norm_k(key)
|
2464
|
+
|
2465
|
+
# Apply RoPE if needed
|
2466
|
+
if rotary_emb is not None:
|
2467
|
+
query_dtype = query.dtype
|
2468
|
+
key_dtype = key.dtype
|
2469
|
+
query = query.to(torch.float32)
|
2470
|
+
key = key.to(torch.float32)
|
2471
|
+
|
2472
|
+
rot_dim = rotary_emb[0].shape[-1]
|
2473
|
+
query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:]
|
2474
|
+
query_rotated = apply_rotary_emb(query_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2)
|
2475
|
+
|
2476
|
+
query = torch.cat((query_rotated, query_unrotated), dim=-1)
|
2477
|
+
|
2478
|
+
if not attn.is_cross_attention:
|
2479
|
+
key_to_rotate, key_unrotated = key[..., :rot_dim], key[..., rot_dim:]
|
2480
|
+
key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2)
|
2481
|
+
|
2482
|
+
key = torch.cat((key_rotated, key_unrotated), dim=-1)
|
2483
|
+
|
2484
|
+
query = query.to(query_dtype)
|
2485
|
+
key = key.to(key_dtype)
|
2486
|
+
|
2487
|
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
2488
|
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
2489
|
+
hidden_states = F.scaled_dot_product_attention(
|
2490
|
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
2491
|
+
)
|
2492
|
+
|
2493
|
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
2494
|
+
hidden_states = hidden_states.to(query.dtype)
|
2495
|
+
|
2496
|
+
# linear proj
|
2497
|
+
hidden_states = attn.to_out[0](hidden_states)
|
2498
|
+
# dropout
|
2499
|
+
hidden_states = attn.to_out[1](hidden_states)
|
2500
|
+
|
2501
|
+
if input_ndim == 4:
|
2502
|
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
2503
|
+
|
2504
|
+
if attn.residual_connection:
|
2505
|
+
hidden_states = hidden_states + residual
|
2506
|
+
|
2507
|
+
hidden_states = hidden_states / attn.rescale_output_factor
|
2508
|
+
|
2509
|
+
return hidden_states
|
2510
|
+
|
2511
|
+
|
2512
|
+
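The Stable Audio processor supports GQA/MQA by letting key/value carry fewer heads than the query and repeating them up to the query head count. Illustrative shapes only:

import torch

heads, kv_heads, head_dim, seq_len = 16, 4, 64, 128
key = torch.randn(1, kv_heads, seq_len, head_dim)
key = torch.repeat_interleave(key, heads // kv_heads, dim=1)  # now (1, heads, seq_len, head_dim)
assert key.shape[1] == heads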
class HunyuanAttnProcessor2_0:
|
2513
|
+
r"""
|
2514
|
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
2515
|
+
used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector.
|
2516
|
+
"""
|
2517
|
+
|
2518
|
+
def __init__(self):
|
2519
|
+
if not hasattr(F, "scaled_dot_product_attention"):
|
2520
|
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
2521
|
+
|
2522
|
+
def __call__(
|
2523
|
+
self,
|
2524
|
+
attn: Attention,
|
2525
|
+
hidden_states: torch.Tensor,
|
2526
|
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
2527
|
+
attention_mask: Optional[torch.Tensor] = None,
|
2528
|
+
temb: Optional[torch.Tensor] = None,
|
2529
|
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
2530
|
+
) -> torch.Tensor:
|
2531
|
+
from .embeddings import apply_rotary_emb
|
2532
|
+
|
2533
|
+
residual = hidden_states
|
2534
|
+
if attn.spatial_norm is not None:
|
2535
|
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
2536
|
+
|
2537
|
+
input_ndim = hidden_states.ndim
|
1440
2538
|
|
1441
2539
|
if input_ndim == 4:
|
1442
2540
|
batch_size, channel, height, width = hidden_states.shape
|
@@ -1473,28 +2571,22 @@ class AttnProcessorNPU:
|
|
1473
2571
|
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1474
2572
|
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1475
2573
|
|
2574
|
+
if attn.norm_q is not None:
|
2575
|
+
query = attn.norm_q(query)
|
2576
|
+
if attn.norm_k is not None:
|
2577
|
+
key = attn.norm_k(key)
|
2578
|
+
|
2579
|
+
# Apply RoPE if needed
|
2580
|
+
if image_rotary_emb is not None:
|
2581
|
+
query = apply_rotary_emb(query, image_rotary_emb)
|
2582
|
+
if not attn.is_cross_attention:
|
2583
|
+
key = apply_rotary_emb(key, image_rotary_emb)
|
2584
|
+
|
1476
2585
|
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
1477
|
-
|
1478
|
-
|
1479
|
-
|
1480
|
-
|
1481
|
-
value,
|
1482
|
-
attn.heads,
|
1483
|
-
input_layout="BNSD",
|
1484
|
-
pse=None,
|
1485
|
-
atten_mask=attention_mask,
|
1486
|
-
scale=1.0 / math.sqrt(query.shape[-1]),
|
1487
|
-
pre_tockens=65536,
|
1488
|
-
next_tockens=65536,
|
1489
|
-
keep_prob=1.0,
|
1490
|
-
sync=False,
|
1491
|
-
inner_precise=0,
|
1492
|
-
)[0]
|
1493
|
-
else:
|
1494
|
-
# TODO: add support for attn.scale when we move to Torch 2.1
|
1495
|
-
hidden_states = F.scaled_dot_product_attention(
|
1496
|
-
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
1497
|
-
)
|
2586
|
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
2587
|
+
hidden_states = F.scaled_dot_product_attention(
|
2588
|
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
2589
|
+
)
|
1498
2590
|
|
1499
2591
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1500
2592
|
hidden_states = hidden_states.to(query.dtype)
|
@@ -1515,14 +2607,18 @@ class AttnProcessorNPU:
|
|
1515
2607
|
return hidden_states
|
1516
2608
|
|
1517
2609
|
|
1518
|
-
class
|
2610
|
+
class FusedHunyuanAttnProcessor2_0:
|
1519
2611
|
r"""
|
1520
|
-
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0)
|
2612
|
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused
|
2613
|
+
projection layers. This is used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on
|
2614
|
+
query and key vector.
|
1521
2615
|
"""
|
1522
2616
|
|
1523
2617
|
def __init__(self):
|
1524
2618
|
if not hasattr(F, "scaled_dot_product_attention"):
|
1525
|
-
raise ImportError(
|
2619
|
+
raise ImportError(
|
2620
|
+
"FusedHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
2621
|
+
)
|
1526
2622
|
|
1527
2623
|
def __call__(
|
1528
2624
|
self,
|
@@ -1531,12 +2627,9 @@ class AttnProcessor2_0:
|
|
1531
2627
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
1532
2628
|
attention_mask: Optional[torch.Tensor] = None,
|
1533
2629
|
temb: Optional[torch.Tensor] = None,
|
1534
|
-
|
1535
|
-
**kwargs,
|
2630
|
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
1536
2631
|
) -> torch.Tensor:
|
1537
|
-
|
1538
|
-
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
1539
|
-
deprecate("scale", "1.0.0", deprecation_message)
|
2632
|
+
from .embeddings import apply_rotary_emb
|
1540
2633
|
|
1541
2634
|
residual = hidden_states
|
1542
2635
|
if attn.spatial_norm is not None:
|
@@ -1561,24 +2654,37 @@ class AttnProcessor2_0:
|
|
1561
2654
|
if attn.group_norm is not None:
|
1562
2655
|
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1563
2656
|
|
1564
|
-
query = attn.to_q(hidden_states)
|
1565
|
-
|
1566
2657
|
if encoder_hidden_states is None:
|
1567
|
-
|
1568
|
-
|
1569
|
-
|
2658
|
+
qkv = attn.to_qkv(hidden_states)
|
2659
|
+
split_size = qkv.shape[-1] // 3
|
2660
|
+
query, key, value = torch.split(qkv, split_size, dim=-1)
|
2661
|
+
else:
|
2662
|
+
if attn.norm_cross:
|
2663
|
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
2664
|
+
query = attn.to_q(hidden_states)
|
1570
2665
|
|
1571
|
-
|
1572
|
-
|
2666
|
+
kv = attn.to_kv(encoder_hidden_states)
|
2667
|
+
split_size = kv.shape[-1] // 2
|
2668
|
+
key, value = torch.split(kv, split_size, dim=-1)
|
1573
2669
|
|
1574
2670
|
inner_dim = key.shape[-1]
|
1575
2671
|
head_dim = inner_dim // attn.heads
|
1576
2672
|
|
1577
2673
|
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1578
|
-
|
1579
2674
|
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1580
2675
|
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1581
2676
|
|
2677
|
+
if attn.norm_q is not None:
|
2678
|
+
query = attn.norm_q(query)
|
2679
|
+
if attn.norm_k is not None:
|
2680
|
+
key = attn.norm_k(key)
|
2681
|
+
|
2682
|
+
# Apply RoPE if needed
|
2683
|
+
if image_rotary_emb is not None:
|
2684
|
+
query = apply_rotary_emb(query, image_rotary_emb)
|
2685
|
+
if not attn.is_cross_attention:
|
2686
|
+
key = apply_rotary_emb(key, image_rotary_emb)
|
2687
|
+
|
1582
2688
|
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
1583
2689
|
# TODO: add support for attn.scale when we move to Torch 2.1
|
1584
2690
|
hidden_states = F.scaled_dot_product_attention(
|
@@ -1604,15 +2710,18 @@ class AttnProcessor2_0:
|
|
1604
2710
|
return hidden_states
|
1605
2711
|
|
1606
2712
|
|
1607
|
-
class
|
2713
|
+
class PAGHunyuanAttnProcessor2_0:
|
1608
2714
|
r"""
|
1609
2715
|
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
1610
|
-
used in the HunyuanDiT model. It applies a
|
2716
|
+
used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This
|
2717
|
+
variant of the processor employs [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377).
|
1611
2718
|
"""
|
1612
2719
|
|
1613
2720
|
def __init__(self):
|
1614
2721
|
if not hasattr(F, "scaled_dot_product_attention"):
|
1615
|
-
raise ImportError(
|
2722
|
+
raise ImportError(
|
2723
|
+
"PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
2724
|
+
)
|
1616
2725
|
|
1617
2726
|
def __call__(
|
1618
2727
|
self,
|
@@ -1635,8 +2744,12 @@ class HunyuanAttnProcessor2_0:
|
|
1635
2744
|
batch_size, channel, height, width = hidden_states.shape
|
1636
2745
|
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1637
2746
|
|
2747
|
+
# chunk
|
2748
|
+
hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
|
2749
|
+
|
2750
|
+
# 1. Original Path
|
1638
2751
|
batch_size, sequence_length, _ = (
|
1639
|
-
|
2752
|
+
hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1640
2753
|
)
|
1641
2754
|
|
1642
2755
|
if attention_mask is not None:
|
@@ -1646,12 +2759,12 @@ class HunyuanAttnProcessor2_0:
|
|
1646
2759
|
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
1647
2760
|
|
1648
2761
|
if attn.group_norm is not None:
|
1649
|
-
|
2762
|
+
hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
|
1650
2763
|
|
1651
|
-
query = attn.to_q(
|
2764
|
+
query = attn.to_q(hidden_states_org)
|
1652
2765
|
|
1653
2766
|
if encoder_hidden_states is None:
|
1654
|
-
encoder_hidden_states =
|
2767
|
+
encoder_hidden_states = hidden_states_org
|
1655
2768
|
elif attn.norm_cross:
|
1656
2769
|
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1657
2770
|
|
@@ -1679,25 +2792,263 @@ class HunyuanAttnProcessor2_0:
|
|
1679
2792
|
|
1680
2793
|
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
1681
2794
|
# TODO: add support for attn.scale when we move to Torch 2.1
|
1682
|
-
|
2795
|
+
hidden_states_org = F.scaled_dot_product_attention(
|
1683
2796
|
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
1684
2797
|
)
|
1685
2798
|
|
1686
|
-
|
1687
|
-
|
2799
|
+
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
2800
|
+
hidden_states_org = hidden_states_org.to(query.dtype)
|
1688
2801
|
|
1689
2802
|
# linear proj
|
1690
|
-
|
2803
|
+
hidden_states_org = attn.to_out[0](hidden_states_org)
|
1691
2804
|
# dropout
|
1692
|
-
|
2805
|
+
hidden_states_org = attn.to_out[1](hidden_states_org)
|
2806
|
+
|
2807
|
+
if input_ndim == 4:
|
2808
|
+
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
2809
|
+
|
2810
|
+
# 2. Perturbed Path
|
2811
|
+
if attn.group_norm is not None:
|
2812
|
+
hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
|
2813
|
+
|
2814
|
+
hidden_states_ptb = attn.to_v(hidden_states_ptb)
|
2815
|
+
hidden_states_ptb = hidden_states_ptb.to(query.dtype)
|
2816
|
+
|
2817
|
+
# linear proj
|
2818
|
+
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
|
2819
|
+
# dropout
|
2820
|
+
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
|
2821
|
+
|
2822
|
+
if input_ndim == 4:
|
2823
|
+
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
2824
|
+
|
2825
|
+
# cat
|
2826
|
+
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
|
2827
|
+
|
2828
|
+
if attn.residual_connection:
|
2829
|
+
hidden_states = hidden_states + residual
|
2830
|
+
|
2831
|
+
hidden_states = hidden_states / attn.rescale_output_factor
|
2832
|
+
|
2833
|
+
return hidden_states
|
2834
|
+
|
2835
|
+
|
2836
|
+
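The PAG processor above expects the caller to stack the original and perturbed samples along the batch axis so they can be separated with `chunk(2)`; the CFG variant below expects three chunks (unconditional, conditional, perturbed). Shape-only sketch with made-up sizes:

import torch

hidden_states = torch.randn(2, 4096, 1152)  # [original, perturbed] stacked on the batch axis
hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)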
class PAGCFGHunyuanAttnProcessor2_0:
|
2837
|
+
r"""
|
2838
|
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
2839
|
+
used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This
|
2840
|
+
variant of the processor employs [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377).
|
2841
|
+
"""
|
2842
|
+
|
2843
|
+
def __init__(self):
|
2844
|
+
if not hasattr(F, "scaled_dot_product_attention"):
|
2845
|
+
raise ImportError(
|
2846
|
+
"PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
2847
|
+
)
|
2848
|
+
|
2849
|
+
def __call__(
|
2850
|
+
self,
|
2851
|
+
attn: Attention,
|
2852
|
+
hidden_states: torch.Tensor,
|
2853
|
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
2854
|
+
attention_mask: Optional[torch.Tensor] = None,
|
2855
|
+
temb: Optional[torch.Tensor] = None,
|
2856
|
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
2857
|
+
) -> torch.Tensor:
|
2858
|
+
from .embeddings import apply_rotary_emb
|
2859
|
+
|
2860
|
+
residual = hidden_states
|
2861
|
+
if attn.spatial_norm is not None:
|
2862
|
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
2863
|
+
|
2864
|
+
input_ndim = hidden_states.ndim
|
2865
|
+
|
2866
|
+
if input_ndim == 4:
|
2867
|
+
batch_size, channel, height, width = hidden_states.shape
|
2868
|
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
2869
|
+
|
2870
|
+
# chunk
|
2871
|
+
hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
|
2872
|
+
hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
|
2873
|
+
|
2874
|
+
# 1. Original Path
|
2875
|
+
batch_size, sequence_length, _ = (
|
2876
|
+
hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
2877
|
+
)
|
2878
|
+
|
2879
|
+
if attention_mask is not None:
|
2880
|
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
2881
|
+
# scaled_dot_product_attention expects attention_mask shape to be
|
2882
|
+
# (batch, heads, source_length, target_length)
|
2883
|
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
2884
|
+
|
2885
|
+
if attn.group_norm is not None:
|
2886
|
+
hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
|
2887
|
+
|
2888
|
+
query = attn.to_q(hidden_states_org)
|
2889
|
+
|
2890
|
+
if encoder_hidden_states is None:
|
2891
|
+
encoder_hidden_states = hidden_states_org
|
2892
|
+
elif attn.norm_cross:
|
2893
|
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
2894
|
+
|
2895
|
+
key = attn.to_k(encoder_hidden_states)
|
2896
|
+
value = attn.to_v(encoder_hidden_states)
|
2897
|
+
|
2898
|
+
inner_dim = key.shape[-1]
|
2899
|
+
head_dim = inner_dim // attn.heads
|
2900
|
+
|
2901
|
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2902
|
+
|
2903
|
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2904
|
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2905
|
+
|
2906
|
+
if attn.norm_q is not None:
|
2907
|
+
query = attn.norm_q(query)
|
2908
|
+
if attn.norm_k is not None:
|
2909
|
+
key = attn.norm_k(key)
|
2910
|
+
|
2911
|
+
# Apply RoPE if needed
|
2912
|
+
if image_rotary_emb is not None:
|
2913
|
+
query = apply_rotary_emb(query, image_rotary_emb)
|
2914
|
+
if not attn.is_cross_attention:
|
2915
|
+
key = apply_rotary_emb(key, image_rotary_emb)
|
2916
|
+
|
2917
|
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
2918
|
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
2919
|
+
hidden_states_org = F.scaled_dot_product_attention(
|
2920
|
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
2921
|
+
)
|
2922
|
+
|
2923
|
+
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
2924
|
+
hidden_states_org = hidden_states_org.to(query.dtype)
|
2925
|
+
|
2926
|
+
# linear proj
|
2927
|
+
hidden_states_org = attn.to_out[0](hidden_states_org)
|
2928
|
+
# dropout
|
2929
|
+
hidden_states_org = attn.to_out[1](hidden_states_org)
|
2930
|
+
|
2931
|
+
if input_ndim == 4:
|
2932
|
+
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
2933
|
+
|
2934
|
+
# 2. Perturbed Path
|
2935
|
+
if attn.group_norm is not None:
|
2936
|
+
hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
|
2937
|
+
|
2938
|
+
hidden_states_ptb = attn.to_v(hidden_states_ptb)
|
2939
|
+
hidden_states_ptb = hidden_states_ptb.to(query.dtype)
|
2940
|
+
|
2941
|
+
# linear proj
|
2942
|
+
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
|
2943
|
+
# dropout
|
2944
|
+
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
|
2945
|
+
|
2946
|
+
if input_ndim == 4:
|
2947
|
+
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
2948
|
+
|
2949
|
+
# cat
|
2950
|
+
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
|
2951
|
+
|
2952
|
+
if attn.residual_connection:
|
2953
|
+
hidden_states = hidden_states + residual
|
2954
|
+
|
2955
|
+
hidden_states = hidden_states / attn.rescale_output_factor
|
2956
|
+
|
2957
|
+
return hidden_states
|
2958
|
+
|
2959
|
+
|
2960
|
+
class LuminaAttnProcessor2_0:
|
2961
|
+
r"""
|
2962
|
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
2963
|
+
used in the LuminaNextDiT model. It applies a s normalization layer and rotary embedding on query and key vector.
|
2964
|
+
"""
|
2965
|
+
|
2966
|
+
def __init__(self):
|
2967
|
+
if not hasattr(F, "scaled_dot_product_attention"):
|
2968
|
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
2969
|
+
|
2970
|
+
def __call__(
|
2971
|
+
self,
|
2972
|
+
attn: Attention,
|
2973
|
+
hidden_states: torch.Tensor,
|
2974
|
+
encoder_hidden_states: torch.Tensor,
|
2975
|
+
attention_mask: Optional[torch.Tensor] = None,
|
2976
|
+
query_rotary_emb: Optional[torch.Tensor] = None,
|
2977
|
+
key_rotary_emb: Optional[torch.Tensor] = None,
|
2978
|
+
base_sequence_length: Optional[int] = None,
|
2979
|
+
) -> torch.Tensor:
|
2980
|
+
from .embeddings import apply_rotary_emb
|
2981
|
+
|
2982
|
+
input_ndim = hidden_states.ndim
|
2983
|
+
|
2984
|
+
if input_ndim == 4:
|
2985
|
+
batch_size, channel, height, width = hidden_states.shape
|
2986
|
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
2987
|
+
|
2988
|
+
batch_size, sequence_length, _ = hidden_states.shape
|
2989
|
+
|
2990
|
+
# Get Query-Key-Value Pair
|
2991
|
+
query = attn.to_q(hidden_states)
|
2992
|
+
key = attn.to_k(encoder_hidden_states)
|
2993
|
+
value = attn.to_v(encoder_hidden_states)
|
2994
|
+
|
2995
|
+
query_dim = query.shape[-1]
|
2996
|
+
inner_dim = key.shape[-1]
|
2997
|
+
head_dim = query_dim // attn.heads
|
2998
|
+
dtype = query.dtype
|
2999
|
+
|
3000
|
+
# Get key-value heads
|
3001
|
+
kv_heads = inner_dim // head_dim
|
3002
|
+
|
3003
|
+
# Apply Query-Key Norm if needed
|
3004
|
+
if attn.norm_q is not None:
|
3005
|
+
query = attn.norm_q(query)
|
3006
|
+
if attn.norm_k is not None:
|
3007
|
+
key = attn.norm_k(key)
|
3008
|
+
|
3009
|
+
query = query.view(batch_size, -1, attn.heads, head_dim)
|
3010
|
+
|
3011
|
+
key = key.view(batch_size, -1, kv_heads, head_dim)
|
3012
|
+
value = value.view(batch_size, -1, kv_heads, head_dim)
|
3013
|
+
|
3014
|
+
# Apply RoPE if needed
|
3015
|
+
if query_rotary_emb is not None:
|
3016
|
+
query = apply_rotary_emb(query, query_rotary_emb, use_real=False)
|
3017
|
+
if key_rotary_emb is not None:
|
3018
|
+
key = apply_rotary_emb(key, key_rotary_emb, use_real=False)
|
3019
|
+
|
3020
|
+
query, key = query.to(dtype), key.to(dtype)
|
3021
|
+
|
3022
|
+
# Apply proportional attention if true
|
3023
|
+
if key_rotary_emb is None:
|
3024
|
+
softmax_scale = None
|
3025
|
+
else:
|
3026
|
+
if base_sequence_length is not None:
|
3027
|
+
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
|
3028
|
+
else:
|
3029
|
+
softmax_scale = attn.scale
|
1693
3030
|
|
1694
|
-
|
1695
|
-
|
3031
|
+
# perform Grouped-qurey Attention (GQA)
|
3032
|
+
n_rep = attn.heads // kv_heads
|
3033
|
+
if n_rep >= 1:
|
3034
|
+
key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
3035
|
+
value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
1696
3036
|
|
1697
|
-
|
1698
|
-
|
3037
|
+
# scaled_dot_product_attention expects attention_mask shape to be
|
3038
|
+
# (batch, heads, source_length, target_length)
|
3039
|
+
attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
|
3040
|
+
attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)
|
1699
3041
|
|
1700
|
-
|
3042
|
+
query = query.transpose(1, 2)
|
3043
|
+
key = key.transpose(1, 2)
|
3044
|
+
value = value.transpose(1, 2)
|
3045
|
+
|
3046
|
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
3047
|
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
3048
|
+
hidden_states = F.scaled_dot_product_attention(
|
3049
|
+
query, key, value, attn_mask=attention_mask, scale=softmax_scale
|
3050
|
+
)
|
3051
|
+
hidden_states = hidden_states.transpose(1, 2).to(dtype)
|
1701
3052
|
|
1702
3053
|
return hidden_states
|
1703
3054
|
|
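The `LuminaAttnProcessor2_0` added in this hunk implements "proportional attention": when `base_sequence_length` is given, the softmax scale becomes `sqrt(log(sequence_length, base_sequence_length)) * attn.scale`, so sequences longer than the training length get a slightly sharper attention distribution. A minimal sketch of that scale computation, assuming `attn.scale` is the usual `1 / sqrt(head_dim)`:

```python
import math

def proportional_scale(sequence_length: int, base_sequence_length: int, head_dim: int) -> float:
    # Mirrors the softmax_scale branch in LuminaAttnProcessor2_0: the usual
    # 1/sqrt(head_dim) scale is stretched by sqrt(log_base(seq_len)).
    base_scale = 1.0 / math.sqrt(head_dim)  # assumption: attn.scale == head_dim ** -0.5
    return math.sqrt(math.log(sequence_length, base_sequence_length)) * base_scale

print(proportional_scale(256, 256, 64))   # 0.125  (factor is exactly 1 at the base length)
print(proportional_scale(1024, 256, 64))  # ~0.14  (sqrt(log_256(1024)) ≈ 1.118)
```

At the base sequence length the extra factor is exactly 1, so behaviour is unchanged at the resolution the model was trained for; the scale only grows for longer sequences.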
@@ -1778,6 +3129,11 @@ class FusedAttnProcessor2_0:
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = F.scaled_dot_product_attention(
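`FusedAttnProcessor2_0` now applies `attn.norm_q` / `attn.norm_k` to the per-head query and key tensors before the dot product, at the same point as the unfused processors above. A minimal sketch of the shape involved; using `nn.LayerNorm` here is an illustrative assumption — which norm module is actually attached depends on how the `Attention` layer was configured:

```python
import torch
from torch import nn

batch, heads, seq_len, head_dim = 2, 8, 16, 64
query = torch.randn(batch, heads, seq_len, head_dim)

# Stand-in for attn.norm_q: normalize over the head_dim axis only.
norm_q = nn.LayerNorm(head_dim)
query = norm_q(query)  # shape unchanged: (batch, heads, seq_len, head_dim)

print(query.shape)  # torch.Size([2, 8, 16, 64])
```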
@@ -2088,7 +3444,7 @@ class SlicedAttnProcessor:
             (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
         )
 
-        for i in range(batch_size_attention // self.slice_size):
+        for i in range((batch_size_attention - 1) // self.slice_size + 1):
             start_idx = i * self.slice_size
             end_idx = (i + 1) * self.slice_size
 
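The loop-bound change above (and the identical one in `SlicedAttnAddedKVProcessor` below) fixes sliced attention when `batch_size_attention` is not an exact multiple of `slice_size`: plain floor division silently dropped the trailing partial slice, while `(n - 1) // s + 1` is ceiling division and always covers it. A quick sketch of the difference:

```python
batch_size_attention, slice_size = 10, 4

old_steps = batch_size_attention // slice_size            # 2 -> rows 8..9 never processed
new_steps = (batch_size_attention - 1) // slice_size + 1  # 3 == ceil(10 / 4)

for i in range(new_steps):
    start_idx = i * slice_size
    end_idx = (i + 1) * slice_size
    print((start_idx, min(end_idx, batch_size_attention)))  # (0, 4), (4, 8), (8, 10)
```

In the processors themselves the over-long `end_idx` on the final step is harmless, since tensor slicing clamps to the end of the dimension; the `min()` above is only to make the printed ranges exact.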
@@ -2185,7 +3541,7 @@ class SlicedAttnAddedKVProcessor:
             (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
         )
 
-        for i in range(batch_size_attention // self.slice_size):
+        for i in range((batch_size_attention - 1) // self.slice_size + 1):
             start_idx = i * self.slice_size
             end_idx = (i + 1) * self.slice_size
 
@@ -2241,264 +3597,6 @@ class SpatialNorm(nn.Module):
         return new_f
 
 
-class LoRAAttnProcessor(nn.Module):
-    def __init__(
-        self,
-        hidden_size: int,
-        cross_attention_dim: Optional[int] = None,
-        rank: int = 4,
-        network_alpha: Optional[int] = None,
-        **kwargs,
-    ):
-        deprecation_message = "Using LoRAAttnProcessor is deprecated. Please use the PEFT backend for all things LoRA. You can install PEFT by running `pip install peft`."
-        deprecate("LoRAAttnProcessor", "0.30.0", deprecation_message, standard_warn=False)
-
-        super().__init__()
-
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-        self.rank = rank
-
-        q_rank = kwargs.pop("q_rank", None)
-        q_hidden_size = kwargs.pop("q_hidden_size", None)
-        q_rank = q_rank if q_rank is not None else rank
-        q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
-
-        v_rank = kwargs.pop("v_rank", None)
-        v_hidden_size = kwargs.pop("v_hidden_size", None)
-        v_rank = v_rank if v_rank is not None else rank
-        v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
-
-        out_rank = kwargs.pop("out_rank", None)
-        out_hidden_size = kwargs.pop("out_hidden_size", None)
-        out_rank = out_rank if out_rank is not None else rank
-        out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
-
-        self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
-        self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
-
-    def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
-        self_cls_name = self.__class__.__name__
-        deprecate(
-            self_cls_name,
-            "0.26.0",
-            (
-                f"Make sure use {self_cls_name[4:]} instead by setting"
-                "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
-                " `LoraLoaderMixin.load_lora_weights`"
-            ),
-        )
-        attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
-        attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
-        attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
-        attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
-
-        attn._modules.pop("processor")
-        attn.processor = AttnProcessor()
-        return attn.processor(attn, hidden_states, **kwargs)
-
-
-class LoRAAttnProcessor2_0(nn.Module):
-    def __init__(
-        self,
-        hidden_size: int,
-        cross_attention_dim: Optional[int] = None,
-        rank: int = 4,
-        network_alpha: Optional[int] = None,
-        **kwargs,
-    ):
-        deprecation_message = "Using LoRAAttnProcessor is deprecated. Please use the PEFT backend for all things LoRA. You can install PEFT by running `pip install peft`."
-        deprecate("LoRAAttnProcessor2_0", "0.30.0", deprecation_message, standard_warn=False)
-
-        super().__init__()
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-        self.rank = rank
-
-        q_rank = kwargs.pop("q_rank", None)
-        q_hidden_size = kwargs.pop("q_hidden_size", None)
-        q_rank = q_rank if q_rank is not None else rank
-        q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
-
-        v_rank = kwargs.pop("v_rank", None)
-        v_hidden_size = kwargs.pop("v_hidden_size", None)
-        v_rank = v_rank if v_rank is not None else rank
-        v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
-
-        out_rank = kwargs.pop("out_rank", None)
-        out_hidden_size = kwargs.pop("out_hidden_size", None)
-        out_rank = out_rank if out_rank is not None else rank
-        out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
-
-        self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
-        self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
-
-    def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
-        self_cls_name = self.__class__.__name__
-        deprecate(
-            self_cls_name,
-            "0.26.0",
-            (
-                f"Make sure use {self_cls_name[4:]} instead by setting"
-                "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
-                " `LoraLoaderMixin.load_lora_weights`"
-            ),
-        )
-        attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
-        attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
-        attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
-        attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
-
-        attn._modules.pop("processor")
-        attn.processor = AttnProcessor2_0()
-        return attn.processor(attn, hidden_states, **kwargs)
-
-
-class LoRAXFormersAttnProcessor(nn.Module):
-    r"""
-    Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
-
-    Args:
-        hidden_size (`int`, *optional*):
-            The hidden size of the attention layer.
-        cross_attention_dim (`int`, *optional*):
-            The number of channels in the `encoder_hidden_states`.
-        rank (`int`, defaults to 4):
-            The dimension of the LoRA update matrices.
-        attention_op (`Callable`, *optional*, defaults to `None`):
-            The base
-            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
-            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
-            operator.
-        network_alpha (`int`, *optional*):
-            Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
-        kwargs (`dict`):
-            Additional keyword arguments to pass to the `LoRALinearLayer` layers.
-    """
-
-    def __init__(
-        self,
-        hidden_size: int,
-        cross_attention_dim: int,
-        rank: int = 4,
-        attention_op: Optional[Callable] = None,
-        network_alpha: Optional[int] = None,
-        **kwargs,
-    ):
-        super().__init__()
-
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-        self.rank = rank
-        self.attention_op = attention_op
-
-        q_rank = kwargs.pop("q_rank", None)
-        q_hidden_size = kwargs.pop("q_hidden_size", None)
-        q_rank = q_rank if q_rank is not None else rank
-        q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
-
-        v_rank = kwargs.pop("v_rank", None)
-        v_hidden_size = kwargs.pop("v_hidden_size", None)
-        v_rank = v_rank if v_rank is not None else rank
-        v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
-
-        out_rank = kwargs.pop("out_rank", None)
-        out_hidden_size = kwargs.pop("out_hidden_size", None)
-        out_rank = out_rank if out_rank is not None else rank
-        out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
-
-        self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
-        self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
-
-    def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
-        self_cls_name = self.__class__.__name__
-        deprecate(
-            self_cls_name,
-            "0.26.0",
-            (
-                f"Make sure use {self_cls_name[4:]} instead by setting"
-                "LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
-                " `LoraLoaderMixin.load_lora_weights`"
-            ),
-        )
-        attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
-        attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
-        attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
-        attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
-
-        attn._modules.pop("processor")
-        attn.processor = XFormersAttnProcessor()
-        return attn.processor(attn, hidden_states, **kwargs)
-
-
-class LoRAAttnAddedKVProcessor(nn.Module):
-    r"""
-    Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
-    encoder.
-
-    Args:
-        hidden_size (`int`, *optional*):
-            The hidden size of the attention layer.
-        cross_attention_dim (`int`, *optional*, defaults to `None`):
-            The number of channels in the `encoder_hidden_states`.
-        rank (`int`, defaults to 4):
-            The dimension of the LoRA update matrices.
-        network_alpha (`int`, *optional*):
-            Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
-        kwargs (`dict`):
-            Additional keyword arguments to pass to the `LoRALinearLayer` layers.
-    """
-
-    def __init__(
-        self,
-        hidden_size: int,
-        cross_attention_dim: Optional[int] = None,
-        rank: int = 4,
-        network_alpha: Optional[int] = None,
-    ):
-        super().__init__()
-
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-        self.rank = rank
-
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
-        self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-        self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-        self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
-        self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
-
-    def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
-        self_cls_name = self.__class__.__name__
-        deprecate(
-            self_cls_name,
-            "0.26.0",
-            (
-                f"Make sure use {self_cls_name[4:]} instead by setting"
-                "LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
-                " `LoraLoaderMixin.load_lora_weights`"
-            ),
-        )
-        attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
-        attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
-        attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
-        attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
-
-        attn._modules.pop("processor")
-        attn.processor = AttnAddedKVProcessor()
-        return attn.processor(attn, hidden_states, **kwargs)
-
-
 class IPAdapterAttnProcessor(nn.Module):
     r"""
     Attention processor for Multiple IP-Adapters.
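The block removed above deletes the long-deprecated `LoRAAttnProcessor*` classes, whose own deprecation messages pointed to the PEFT backend. A minimal sketch of that replacement path, assuming `peft` is installed; the model id and LoRA file names below are placeholders, not values taken from this diff:

```python
# Sketch only: load LoRA weights through the PEFT-backed loader instead of
# constructing LoRAAttnProcessor objects by hand (requires `pip install peft`).
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder base model
    torch_dtype=torch.float16,
)
pipe.load_lora_weights("path/to/lora_dir", weight_name="lora.safetensors")  # placeholder paths
pipe.fuse_lora()  # optional: fold the LoRA deltas into the base weights
```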
@@ -2927,19 +4025,233 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
         return hidden_states
 
 
-LORA_ATTENTION_PROCESSORS = (
-    LoRAAttnProcessor,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    LoRAAttnAddedKVProcessor,
-)
+class PAGIdentitySelfAttnProcessor2_0:
+    r"""
+    Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    PAG reference: https://arxiv.org/abs/2403.17377
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "PAGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        # chunk
+        hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
+
+        # original path
+        batch_size, sequence_length, _ = hidden_states_org.shape
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states_org)
+        key = attn.to_k(hidden_states_org)
+        value = attn.to_v(hidden_states_org)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states_org = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states_org = hidden_states_org.to(query.dtype)
+
+        # linear proj
+        hidden_states_org = attn.to_out[0](hidden_states_org)
+        # dropout
+        hidden_states_org = attn.to_out[1](hidden_states_org)
+
+        if input_ndim == 4:
+            hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        # perturbed path (identity attention)
+        batch_size, sequence_length, _ = hidden_states_ptb.shape
+
+        if attn.group_norm is not None:
+            hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
+
+        hidden_states_ptb = attn.to_v(hidden_states_ptb)
+        hidden_states_ptb = hidden_states_ptb.to(query.dtype)
+
+        # linear proj
+        hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+        # dropout
+        hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+
+        if input_ndim == 4:
+            hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        # cat
+        hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class PAGCFGIdentitySelfAttnProcessor2_0:
+    r"""
+    Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    PAG reference: https://arxiv.org/abs/2403.17377
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        # chunk
+        hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
+        hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
+
+        # original path
+        batch_size, sequence_length, _ = hidden_states_org.shape
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states_org)
+        key = attn.to_k(hidden_states_org)
+        value = attn.to_v(hidden_states_org)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states_org = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states_org = hidden_states_org.to(query.dtype)
+
+        # linear proj
+        hidden_states_org = attn.to_out[0](hidden_states_org)
+        # dropout
+        hidden_states_org = attn.to_out[1](hidden_states_org)
+
+        if input_ndim == 4:
+            hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        # perturbed path (identity attention)
+        batch_size, sequence_length, _ = hidden_states_ptb.shape
+
+        if attn.group_norm is not None:
+            hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
+
+        value = attn.to_v(hidden_states_ptb)
+        hidden_states_ptb = value
+        hidden_states_ptb = hidden_states_ptb.to(query.dtype)
+
+        # linear proj
+        hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+        # dropout
+        hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+
+        if input_ndim == 4:
+            hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        # cat
+        hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class LoRAAttnProcessor:
+    def __init__(self):
+        pass
+
+
+class LoRAAttnProcessor2_0:
+    def __init__(self):
+        pass
+
+
+class LoRAXFormersAttnProcessor:
+    def __init__(self):
+        pass
+
+
+class LoRAAttnAddedKVProcessor:
+    def __init__(self):
+        pass
+
 
 ADDED_KV_ATTENTION_PROCESSORS = (
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
     AttnAddedKVProcessor2_0,
     XFormersAttnAddedKVProcessor,
-    LoRAAttnAddedKVProcessor,
 )
 
 CROSS_ATTENTION_PROCESSORS = (
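In the two identity self-attention processors added above, the perturbed half of the batch skips `softmax(QKᵀ)V` entirely: every token attends only to itself, so the attention output collapses to the value projection, which is then pushed through the usual output projection. A tensor-level sketch of the two paths, using plain `nn.Linear` layers as single-head stand-ins for `attn.to_q` / `attn.to_k` / `attn.to_v` / `attn.to_out`:

```python
import torch
import torch.nn.functional as F
from torch import nn

dim, seq_len = 64, 16
to_q, to_k, to_v, to_out = (nn.Linear(dim, dim) for _ in range(4))

x = torch.randn(2, seq_len, dim)  # (batch, tokens, channels), single head for brevity

# original path: ordinary self-attention
original = to_out(F.scaled_dot_product_attention(to_q(x), to_k(x), to_v(x)))

# perturbed path: "identity" attention -- the attention map is the identity,
# so softmax(QK^T)V reduces to just V
perturbed = to_out(to_v(x))

print(original.shape, perturbed.shape)  # both torch.Size([2, 16, 64])
```

The processors run both paths on a batch the PAG pipelines have already duplicated (`chunk(2)`, or `chunk(3)` in the CFG variants) and concatenate the results, so the guidance code downstream can combine the original and perturbed predictions.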
@@ -2947,9 +4259,6 @@ CROSS_ATTENTION_PROCESSORS = (
     AttnProcessor2_0,
     XFormersAttnProcessor,
     SlicedAttnProcessor,
-    LoRAAttnProcessor,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
     IPAdapterAttnProcessor,
     IPAdapterAttnProcessor2_0,
 )
@@ -2967,9 +4276,8 @@ AttentionProcessor = Union[
     CustomDiffusionAttnProcessor,
     CustomDiffusionXFormersAttnProcessor,
     CustomDiffusionAttnProcessor2_0,
-    # deprecated
-    LoRAAttnProcessor,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    LoRAAttnAddedKVProcessor,
+    PAGCFGIdentitySelfAttnProcessor2_0,
+    PAGIdentitySelfAttnProcessor2_0,
+    PAGCFGHunyuanAttnProcessor2_0,
+    PAGHunyuanAttnProcessor2_0,
 ]