diffusers 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +66 -5
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +25 -17
- diffusers/loaders/__init__.py +22 -3
- diffusers/loaders/ip_adapter.py +538 -15
- diffusers/loaders/lora_base.py +124 -118
- diffusers/loaders/lora_conversion_utils.py +318 -3
- diffusers/loaders/lora_pipeline.py +1688 -368
- diffusers/loaders/peft.py +379 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +519 -9
- diffusers/loaders/textual_inversion.py +3 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +17 -4
- diffusers/models/__init__.py +47 -14
- diffusers/models/activations.py +22 -9
- diffusers/models/attention.py +13 -4
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2059 -281
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +2 -1
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +29 -495
- diffusers/models/controlnet_sd3.py +25 -379
- diffusers/models/controlnet_sparsectrl.py +46 -718
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +838 -43
- diffusers/models/model_loading_utils.py +88 -6
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +74 -28
- diffusers/models/normalization.py +78 -13
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +1 -1
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +1 -1
- diffusers/models/transformers/transformer_flux.py +30 -9
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +105 -17
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +1 -1
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +5 -5
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +8 -0
- diffusers/pipelines/__init__.py +34 -0
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
- diffusers/pipelines/auto_pipeline.py +53 -6
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
- diffusers/pipelines/flux/__init__.py +13 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +204 -29
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +7 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +25 -4
- diffusers/pipelines/pipeline_utils.py +35 -6
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/auto.py +14 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +280 -2
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddpm.py +2 -6
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
- diffusers/schedulers/scheduling_deis_multistep.py +28 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
- diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
- diffusers/schedulers/scheduling_euler_discrete.py +4 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_heun_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +28 -9
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
- diffusers/training_utils.py +16 -2
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +180 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +31 -39
- diffusers/utils/import_utils.py +67 -0
- diffusers/utils/peft_utils.py +3 -0
- diffusers/utils/testing_utils.py +56 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/METADATA +69 -69
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/RECORD +214 -162
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/models/attention_processor.py

@@ -20,8 +20,8 @@ import torch.nn.functional as F
 from torch import nn

 from ..image_processor import IPAdapterMaskProcessor
-from ..utils import deprecate, logging
-from ..utils.import_utils import is_torch_npu_available, is_xformers_available
+from ..utils import deprecate, is_torch_xla_available, logging
+from ..utils.import_utils import is_torch_npu_available, is_torch_xla_version, is_xformers_available
 from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph

@@ -36,6 +36,15 @@ if is_xformers_available():
 else:
     xformers = None

+if is_torch_xla_available():
+    # flash attention pallas kernel is introduced in the torch_xla 2.3 release.
+    if is_torch_xla_version(">", "2.2"):
+        from torch_xla.experimental.custom_kernel import flash_attention
+        from torch_xla.runtime import is_spmd
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+

 @maybe_allow_in_graph
 class Attention(nn.Module):
@@ -120,14 +129,16 @@ class Attention(nn.Module):
         _from_deprecated_attn_block: bool = False,
         processor: Optional["AttnProcessor"] = None,
         out_dim: int = None,
+        out_context_dim: int = None,
         context_pre_only=None,
         pre_only=False,
         elementwise_affine: bool = True,
+        is_causal: bool = False,
     ):
         super().__init__()

         # To prevent circular import.
-        from .normalization import FP32LayerNorm, RMSNorm
+        from .normalization import FP32LayerNorm, LpNorm, RMSNorm

         self.inner_dim = out_dim if out_dim is not None else dim_head * heads
         self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
@@ -142,8 +153,10 @@ class Attention(nn.Module):
         self.dropout = dropout
         self.fused_projections = False
         self.out_dim = out_dim if out_dim is not None else query_dim
+        self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
         self.context_pre_only = context_pre_only
         self.pre_only = pre_only
+        self.is_causal = is_causal

         # we make use of this private variable to know whether this class is loaded
         # with an deprecated state dict so that we can convert it on the fly
@@ -186,12 +199,19 @@ class Attention(nn.Module):
             self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
             self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
         elif qk_norm == "layer_norm_across_heads":
-            # Lumina
+            # Lumina applies qk norm across all heads
             self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
             self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
         elif qk_norm == "rms_norm":
             self.norm_q = RMSNorm(dim_head, eps=eps)
             self.norm_k = RMSNorm(dim_head, eps=eps)
+        elif qk_norm == "rms_norm_across_heads":
+            # LTX applies qk norm across all heads
+            self.norm_q = RMSNorm(dim_head * heads, eps=eps)
+            self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps)
+        elif qk_norm == "l2":
+            self.norm_q = LpNorm(p=2, dim=-1, eps=eps)
+            self.norm_k = LpNorm(p=2, dim=-1, eps=eps)
         else:
             raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'")

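Note on the hunk above: besides the existing `rms_norm`, the constructor now accepts `rms_norm_across_heads` (normalizing over `dim_head * heads`, used by LTX) and `l2` (an `LpNorm(p=2)` over the last dimension). A minimal sketch of the `l2` path, with toy dimensions chosen for illustration (not taken from the diff):

    import torch
    from diffusers.models.attention_processor import Attention

    # Toy sizes; "l2" wraps queries/keys in LpNorm(p=2, dim=-1), so it works with the
    # default scaled-dot-product processor regardless of how heads are laid out.
    attn = Attention(query_dim=128, heads=4, dim_head=32, qk_norm="l2")
    out = attn(torch.randn(2, 16, 128))
    print(out.shape)  # torch.Size([2, 16, 128])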
@@ -234,14 +254,22 @@ class Attention(nn.Module):
             self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
             if self.context_pre_only is not None:
                 self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        else:
+            self.add_q_proj = None
+            self.add_k_proj = None
+            self.add_v_proj = None

         if not self.pre_only:
             self.to_out = nn.ModuleList([])
             self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
             self.to_out.append(nn.Dropout(dropout))
+        else:
+            self.to_out = None

         if self.context_pre_only is not None and not self.context_pre_only:
-            self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
+            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
+        else:
+            self.to_add_out = None

         if qk_norm is not None and added_kv_proj_dim is not None:
             if qk_norm == "fp32_layer_norm":
@@ -268,6 +296,33 @@ class Attention(nn.Module):
         )
         self.set_processor(processor)

+    def set_use_xla_flash_attention(
+        self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None
+    ) -> None:
+        r"""
+        Set whether to use xla flash attention from `torch_xla` or not.
+
+        Args:
+            use_xla_flash_attention (`bool`):
+                Whether to use pallas flash attention kernel from `torch_xla` or not.
+            partition_spec (`Tuple[]`, *optional*):
+                Specify the partition specification if using SPMD. Otherwise None.
+        """
+        if use_xla_flash_attention:
+            if not is_torch_xla_available:
+                raise "torch_xla is not available"
+            elif is_torch_xla_version("<", "2.3"):
+                raise "flash attention pallas kernel is supported from torch_xla version 2.3"
+            elif is_spmd() and is_torch_xla_version("<", "2.4"):
+                raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4"
+            else:
+                processor = XLAFlashAttnProcessor2_0(partition_spec)
+        else:
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )
+        self.set_processor(processor)
+
     def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
         r"""
         Set whether to use npu flash attention from `torch_npu` or not.
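The new `set_use_xla_flash_attention` helper swaps the module's processor for `XLAFlashAttnProcessor2_0` when `torch_xla` >= 2.3 is present (and accepts a `partition_spec` for SPMD runs on >= 2.4). A minimal sketch, assuming an ordinary CPU/GPU host where only the `False` path is exercised:

    from diffusers.models.attention_processor import Attention

    attn = Attention(query_dim=64, heads=8, dim_head=8)

    # On a TPU host with torch_xla >= 2.3 this would install the pallas-kernel processor:
    # attn.set_use_xla_flash_attention(True, partition_spec=None)

    # Passing False restores the default processor on any machine.
    attn.set_use_xla_flash_attention(False)
    print(type(attn.processor).__name__)  # AttnProcessor2_0 on torch >= 2.0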
@@ -311,6 +366,17 @@ class Attention(nn.Module):
                 XFormersAttnAddedKVProcessor,
             ),
         )
+        is_ip_adapter = hasattr(self, "processor") and isinstance(
+            self.processor,
+            (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
+        )
+        is_joint_processor = hasattr(self, "processor") and isinstance(
+            self.processor,
+            (
+                JointAttnProcessor2_0,
+                XFormersJointAttnProcessor,
+            ),
+        )

         if use_memory_efficient_attention_xformers:
             if is_added_kv_processor and is_custom_diffusion:
@@ -361,6 +427,21 @@ class Attention(nn.Module):
                     "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
                 )
                 processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
+            elif is_ip_adapter:
+                processor = IPAdapterXFormersAttnProcessor(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    num_tokens=self.processor.num_tokens,
+                    scale=self.processor.scale,
+                    attention_op=attention_op,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_ip"):
+                    processor.to(
+                        device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
+                    )
+            elif is_joint_processor:
+                processor = XFormersJointAttnProcessor(attention_op=attention_op)
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
@@ -379,6 +460,18 @@ class Attention(nn.Module):
                 processor.load_state_dict(self.processor.state_dict())
                 if hasattr(self.processor, "to_k_custom_diffusion"):
                     processor.to(self.processor.to_k_custom_diffusion.weight.device)
+            elif is_ip_adapter:
+                processor = IPAdapterAttnProcessor2_0(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    num_tokens=self.processor.num_tokens,
+                    scale=self.processor.scale,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_ip"):
+                    processor.to(
+                        device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
+                    )
             else:
                 # set attention processor
                 # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
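Taken together, the two branches above mean toggling xFormers no longer discards a loaded IP-Adapter: its processors are converted to `IPAdapterXFormersAttnProcessor` (or back to `IPAdapterAttnProcessor2_0`) with the `to_k_ip`/`to_v_ip` weights copied over. A hedged sketch of the intended call order (the model ids are the usual public checkpoints, shown only for illustration):

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")
    pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

    # With 0.32 this keeps the IP-Adapter projections instead of resetting the processors.
    pipe.enable_xformers_memory_efficient_attention()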
@@ -482,7 +575,7 @@ class Attention(nn.Module):
         # For standard processors that are defined here, `**cross_attention_kwargs` is empty

         attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
-        quiet_attn_parameters = {"ip_adapter_masks"}
+        quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
         unused_kwargs = [
             k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
         ]
@@ -697,7 +790,11 @@ class Attention(nn.Module):
             self.to_kv.bias.copy_(concatenated_bias)

         # handle added projections for SD3 and others.
-        if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
+        if (
+            getattr(self, "add_q_proj", None) is not None
+            and getattr(self, "add_k_proj", None) is not None
+            and getattr(self, "add_v_proj", None) is not None
+        ):
             concatenated_weights = torch.cat(
                 [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
             )
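The `getattr(..., None) is not None` guard matters because `add_q_proj`/`add_k_proj`/`add_v_proj` can now be explicitly `None` (see the constructor changes earlier), where a bare attribute check would wrongly take the fused-added-projection path. A small sketch of the plain self-attention case, with sizes invented for the example:

    from diffusers.models.attention_processor import Attention

    attn = Attention(query_dim=64, heads=8, dim_head=8)  # no added KV projections -> add_*_proj are None
    attn.fuse_projections(fuse=True)                     # skips the added-projection branch above
    print(hasattr(attn, "to_qkv"), attn.fused_projections)  # True True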
@@ -717,6 +814,269 @@ class Attention(nn.Module):
         self.fused_projections = fuse


+class SanaMultiscaleAttentionProjection(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        num_attention_heads: int,
+        kernel_size: int,
+    ) -> None:
+        super().__init__()
+
+        channels = 3 * in_channels
+        self.proj_in = nn.Conv2d(
+            channels,
+            channels,
+            kernel_size,
+            padding=kernel_size // 2,
+            groups=channels,
+            bias=False,
+        )
+        self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.proj_in(hidden_states)
+        hidden_states = self.proj_out(hidden_states)
+        return hidden_states
+
+
+class SanaMultiscaleLinearAttention(nn.Module):
+    r"""Lightweight multi-scale linear attention"""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_attention_heads: Optional[int] = None,
+        attention_head_dim: int = 8,
+        mult: float = 1.0,
+        norm_type: str = "batch_norm",
+        kernel_sizes: Tuple[int, ...] = (5,),
+        eps: float = 1e-15,
+        residual_connection: bool = False,
+    ):
+        super().__init__()
+
+        # To prevent circular import
+        from .normalization import get_normalization
+
+        self.eps = eps
+        self.attention_head_dim = attention_head_dim
+        self.norm_type = norm_type
+        self.residual_connection = residual_connection
+
+        num_attention_heads = (
+            int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads
+        )
+        inner_dim = num_attention_heads * attention_head_dim
+
+        self.to_q = nn.Linear(in_channels, inner_dim, bias=False)
+        self.to_k = nn.Linear(in_channels, inner_dim, bias=False)
+        self.to_v = nn.Linear(in_channels, inner_dim, bias=False)
+
+        self.to_qkv_multiscale = nn.ModuleList()
+        for kernel_size in kernel_sizes:
+            self.to_qkv_multiscale.append(
+                SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
+            )
+
+        self.nonlinearity = nn.ReLU()
+        self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
+        self.norm_out = get_normalization(norm_type, num_features=out_channels)
+
+        self.processor = SanaMultiscaleAttnProcessor2_0()
+
+    def apply_linear_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)  # Adds padding
+        scores = torch.matmul(value, key.transpose(-1, -2))
+        hidden_states = torch.matmul(scores, query)
+
+        hidden_states = hidden_states.to(dtype=torch.float32)
+        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
+        return hidden_states
+
+    def apply_quadratic_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+        scores = torch.matmul(key.transpose(-1, -2), query)
+        scores = scores.to(dtype=torch.float32)
+        scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
+        hidden_states = torch.matmul(value, scores)
+        return hidden_states
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return self.processor(self, hidden_states)
+
+
+class MochiAttention(nn.Module):
+    def __init__(
+        self,
+        query_dim: int,
+        added_kv_proj_dim: int,
+        processor: "MochiAttnProcessor2_0",
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        added_proj_bias: bool = True,
+        out_dim: Optional[int] = None,
+        out_context_dim: Optional[int] = None,
+        out_bias: bool = True,
+        context_pre_only: bool = False,
+        eps: float = 1e-5,
+    ):
+        super().__init__()
+        from .normalization import MochiRMSNorm
+
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.out_dim = out_dim if out_dim is not None else query_dim
+        self.out_context_dim = out_context_dim if out_context_dim else query_dim
+        self.context_pre_only = context_pre_only
+
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+
+        self.norm_q = MochiRMSNorm(dim_head, eps, True)
+        self.norm_k = MochiRMSNorm(dim_head, eps, True)
+        self.norm_added_q = MochiRMSNorm(dim_head, eps, True)
+        self.norm_added_k = MochiRMSNorm(dim_head, eps, True)
+
+        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+        self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        if self.context_pre_only is not None:
+            self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(nn.Dropout(dropout))
+
+        if not self.context_pre_only:
+            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
+
+        self.processor = processor
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        return self.processor(
+            self,
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            **kwargs,
+        )
+
+
+class MochiAttnProcessor2_0:
+    """Attention processor used in Mochi."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: "MochiAttention",
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        encoder_query = attn.add_q_proj(encoder_hidden_states)
+        encoder_key = attn.add_k_proj(encoder_hidden_states)
+        encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
+        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
+        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
+
+        if attn.norm_added_q is not None:
+            encoder_query = attn.norm_added_q(encoder_query)
+        if attn.norm_added_k is not None:
+            encoder_key = attn.norm_added_k(encoder_key)
+
+        if image_rotary_emb is not None:
+
+            def apply_rotary_emb(x, freqs_cos, freqs_sin):
+                x_even = x[..., 0::2].float()
+                x_odd = x[..., 1::2].float()
+
+                cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
+                sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
+
+                return torch.stack([cos, sin], dim=-1).flatten(-2)
+
+            query = apply_rotary_emb(query, *image_rotary_emb)
+            key = apply_rotary_emb(key, *image_rotary_emb)
+
+        query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
+        encoder_query, encoder_key, encoder_value = (
+            encoder_query.transpose(1, 2),
+            encoder_key.transpose(1, 2),
+            encoder_value.transpose(1, 2),
+        )
+
+        sequence_length = query.size(2)
+        encoder_sequence_length = encoder_query.size(2)
+        total_length = sequence_length + encoder_sequence_length
+
+        batch_size, heads, _, dim = query.shape
+        attn_outputs = []
+        for idx in range(batch_size):
+            mask = attention_mask[idx][None, :]
+            valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
+
+            valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :]
+            valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :]
+            valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :]
+
+            valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2)
+            valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2)
+            valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2)
+
+            attn_output = F.scaled_dot_product_attention(
+                valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False
+            )
+            valid_sequence_length = attn_output.size(2)
+            attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
+            attn_outputs.append(attn_output)
+
+        hidden_states = torch.cat(attn_outputs, dim=0)
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+
+        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
+            (sequence_length, encoder_sequence_length), dim=1
+        )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if hasattr(attn, "to_add_out"):
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        return hidden_states, encoder_hidden_states
+
+
 class AttnProcessor:
     r"""
     Default processor for performing attention-related computations.
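For reference, the nested `apply_rotary_emb` helper in `MochiAttnProcessor2_0` above treats even/odd channels as real/imaginary pairs and rotates them by per-position angles. A standalone copy with toy shapes (the shapes are illustrative, not taken from the diff):

    import torch

    def apply_rotary_emb(x, freqs_cos, freqs_sin):
        # Even/odd channels act as real/imag parts rotated by (freqs_cos, freqs_sin).
        x_even = x[..., 0::2].float()
        x_odd = x[..., 1::2].float()
        cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
        sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
        return torch.stack([cos, sin], dim=-1).flatten(-2)

    q = torch.randn(1, 8, 16, 64)          # (batch, heads, seq, head_dim)
    theta = torch.rand(16, 32) * 3.14159   # one angle per (position, channel pair)
    out = apply_rotary_emb(q, theta.cos(), theta.sin())
    print(out.shape)  # torch.Size([1, 8, 16, 64])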
@@ -1136,6 +1496,7 @@ class PAGJointAttnProcessor2_0:
         attn: Attention,
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
     ) -> torch.FloatTensor:
         residual = hidden_states

@@ -1521,92 +1882,84 @@ class FusedJointAttnProcessor2_0:
         return hidden_states, encoder_hidden_states


-class AuraFlowAttnProcessor2_0:
-    """Attention processor used typically in processing Aura Flow."""
+class XFormersJointAttnProcessor:
+    r"""
+    Processor for implementing memory efficient attention using xFormers.

-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
-            raise ImportError(
-                "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
-            )
+    Args:
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+    """
+
+    def __init__(self, attention_op: Optional[Callable] = None):
+        self.attention_op = attention_op

     def __call__(
         self,
         attn: Attention,
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
         *args,
         **kwargs,
     ) -> torch.FloatTensor:
-        batch_size = hidden_states.shape[0]
+        residual = hidden_states

         # `sample` projections.
         query = attn.to_q(hidden_states)
         key = attn.to_k(hidden_states)
         value = attn.to_v(hidden_states)

-        # `context` projections.
-        if encoder_hidden_states is not None:
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-        # Reshape.
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-        query = query.view(batch_size, -1, attn.heads, head_dim)
-        key = key.view(batch_size, -1, attn.heads, head_dim)
-        value = value.view(batch_size, -1, attn.heads, head_dim)
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()

-        # Apply QK norm.
         if attn.norm_q is not None:
             query = attn.norm_q(query)
         if attn.norm_k is not None:
             key = attn.norm_k(key)

-        # Concatenate the projections.
+        # `context` projections.
         if encoder_hidden_states is not None:
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            )
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            )
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = attn.head_to_batch_dim(encoder_hidden_states_query_proj).contiguous()
+            encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj).contiguous()
+            encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj).contiguous()

             if attn.norm_added_q is not None:
                 encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
             if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj)
-
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

-        query = query.transpose(1, 2)
-        key = key.transpose(1, 2)
-        value = value.transpose(1, 2)
+            query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+            key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+            value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

-        # Attention.
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
         )
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)

-        # Split the attention outputs.
         if encoder_hidden_states is not None:
+            # Split the attention outputs.
             hidden_states, encoder_hidden_states = (
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, : residual.shape[1]],
+                hidden_states[:, residual.shape[1] :],
             )
+            if not attn.context_pre_only:
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
-        if encoder_hidden_states is not None:
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

         if encoder_hidden_states is not None:
             return hidden_states, encoder_hidden_states
@@ -1614,27 +1967,214 @@ class AuraFlowAttnProcessor2_0:
         return hidden_states


-class FusedAuraFlowAttnProcessor2_0:
-    """Attention processor used typically in processing Aura Flow with fused projections."""
+class AllegroAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the Allegro model. It applies a normalization layer and rotary embedding on the query and key vector.
+    """

     def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
+        if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(
-                "FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
+                "AllegroAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
             )

     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        *args,
-        **kwargs,
-    ) -> torch.FloatTensor:
-        batch_size = hidden_states.shape[0]
-
-        # `sample` projections.
-        qkv = attn.to_qkv(hidden_states)
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None and not attn.is_cross_attention:
+            from .embeddings import apply_rotary_emb_allegro
+
+            query = apply_rotary_emb_allegro(query, image_rotary_emb[0], image_rotary_emb[1])
+            key = apply_rotary_emb_allegro(key, image_rotary_emb[0], image_rotary_emb[1])
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class AuraFlowAttnProcessor2_0:
+    """Attention processor used typically in processing Aura Flow."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
+            raise ImportError(
+                "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        batch_size = hidden_states.shape[0]
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+        # Reshape.
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim)
+        key = key.view(batch_size, -1, attn.heads, head_dim)
+        value = value.view(batch_size, -1, attn.heads, head_dim)
+
+        # Apply QK norm.
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Concatenate the projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            )
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            )
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj)
+
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        # Attention.
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # Split the attention outputs.
+        if encoder_hidden_states is not None:
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+            )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        if encoder_hidden_states is not None:
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        if encoder_hidden_states is not None:
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
+class FusedAuraFlowAttnProcessor2_0:
+    """Attention processor used typically in processing Aura Flow with fused projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
+            raise ImportError(
+                "FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        batch_size = hidden_states.shape[0]
+
+        # `sample` projections.
+        qkv = attn.to_qkv(hidden_states)
         split_size = qkv.shape[-1] // 3
         query, key, value = torch.split(qkv, split_size, dim=-1)

@@ -1778,7 +2318,321 @@ class FluxAttnProcessor2_0:
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)

-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
+class FluxAttnProcessor2_0_NPU:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        if encoder_hidden_states is not None:
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        if query.dtype in (torch.float16, torch.bfloat16):
+            hidden_states = torch_npu.npu_fusion_attention(
+                query,
+                key,
+                value,
+                attn.heads,
+                input_layout="BNSD",
+                pse=None,
+                scale=1.0 / math.sqrt(query.shape[-1]),
+                pre_tockens=65536,
+                next_tockens=65536,
+                keep_prob=1.0,
+                sync=False,
+                inner_precise=0,
+            )[0]
+        else:
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
+class FusedFluxAttnProcessor2_0:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+            split_size = encoder_qkv.shape[-1] // 3
+            (
+                encoder_hidden_states_query_proj,
+                encoder_hidden_states_key_proj,
+                encoder_hidden_states_value_proj,
+            ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
+class FusedFluxAttnProcessor2_0_NPU:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
|
2547
|
+
|
2548
|
+
def __init__(self):
|
2549
|
+
if not hasattr(F, "scaled_dot_product_attention"):
|
2550
|
+
raise ImportError(
|
2551
|
+
"FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU"
|
2552
|
+
)
|
2553
|
+
|
2554
|
+
def __call__(
|
2555
|
+
self,
|
2556
|
+
attn: Attention,
|
2557
|
+
hidden_states: torch.FloatTensor,
|
2558
|
+
encoder_hidden_states: torch.FloatTensor = None,
|
2559
|
+
attention_mask: Optional[torch.FloatTensor] = None,
|
2560
|
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
2561
|
+
) -> torch.FloatTensor:
|
2562
|
+
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
2563
|
+
|
2564
|
+
# `sample` projections.
|
2565
|
+
qkv = attn.to_qkv(hidden_states)
|
2566
|
+
split_size = qkv.shape[-1] // 3
|
2567
|
+
query, key, value = torch.split(qkv, split_size, dim=-1)
|
2568
|
+
|
2569
|
+
inner_dim = key.shape[-1]
|
2570
|
+
head_dim = inner_dim // attn.heads
|
2571
|
+
|
2572
|
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2573
|
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2574
|
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2575
|
+
|
2576
|
+
if attn.norm_q is not None:
|
2577
|
+
query = attn.norm_q(query)
|
2578
|
+
if attn.norm_k is not None:
|
2579
|
+
key = attn.norm_k(key)
|
2580
|
+
|
2581
|
+
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
|
2582
|
+
# `context` projections.
|
2583
|
+
if encoder_hidden_states is not None:
|
2584
|
+
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
|
2585
|
+
split_size = encoder_qkv.shape[-1] // 3
|
2586
|
+
(
|
2587
|
+
encoder_hidden_states_query_proj,
|
2588
|
+
encoder_hidden_states_key_proj,
|
2589
|
+
encoder_hidden_states_value_proj,
|
2590
|
+
) = torch.split(encoder_qkv, split_size, dim=-1)
|
2591
|
+
|
2592
|
+
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
2593
|
+
batch_size, -1, attn.heads, head_dim
|
2594
|
+
).transpose(1, 2)
|
2595
|
+
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
2596
|
+
batch_size, -1, attn.heads, head_dim
|
2597
|
+
).transpose(1, 2)
|
2598
|
+
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
2599
|
+
batch_size, -1, attn.heads, head_dim
|
2600
|
+
).transpose(1, 2)
|
2601
|
+
|
2602
|
+
if attn.norm_added_q is not None:
|
2603
|
+
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
2604
|
+
if attn.norm_added_k is not None:
|
2605
|
+
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
2606
|
+
|
2607
|
+
# attention
|
2608
|
+
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
2609
|
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
2610
|
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
2611
|
+
|
2612
|
+
if image_rotary_emb is not None:
|
2613
|
+
from .embeddings import apply_rotary_emb
|
2614
|
+
|
2615
|
+
query = apply_rotary_emb(query, image_rotary_emb)
|
2616
|
+
key = apply_rotary_emb(key, image_rotary_emb)
|
2617
|
+
|
2618
|
+
if query.dtype in (torch.float16, torch.bfloat16):
|
2619
|
+
hidden_states = torch_npu.npu_fusion_attention(
|
2620
|
+
query,
|
2621
|
+
key,
|
2622
|
+
value,
|
2623
|
+
attn.heads,
|
2624
|
+
input_layout="BNSD",
|
2625
|
+
pse=None,
|
2626
|
+
scale=1.0 / math.sqrt(query.shape[-1]),
|
2627
|
+
pre_tockens=65536,
|
2628
|
+
next_tockens=65536,
|
2629
|
+
keep_prob=1.0,
|
2630
|
+
sync=False,
|
2631
|
+
inner_precise=0,
|
2632
|
+
)[0]
|
2633
|
+
else:
|
2634
|
+
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
2635
|
+
|
1782
2636
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1783
2637
|
hidden_states = hidden_states.to(query.dtype)
|
1784
2638
|
|
@@ -1799,15 +2653,44 @@ class FluxAttnProcessor2_0:
|
|
1799
2653
|
return hidden_states
|
1800
2654
|
|
1801
2655
|
|
1802
|
-
class
|
1803
|
-
"""Attention processor
|
2656
|
+
class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
|
2657
|
+
"""Flux Attention processor for IP-Adapter."""
|
2658
|
+
|
2659
|
+
def __init__(
|
2660
|
+
self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
|
2661
|
+
):
|
2662
|
+
super().__init__()
|
1804
2663
|
|
1805
|
-
def __init__(self):
|
1806
2664
|
if not hasattr(F, "scaled_dot_product_attention"):
|
1807
2665
|
raise ImportError(
|
1808
|
-
"
|
2666
|
+
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
1809
2667
|
)
|
1810
2668
|
|
2669
|
+
self.hidden_size = hidden_size
|
2670
|
+
self.cross_attention_dim = cross_attention_dim
|
2671
|
+
|
2672
|
+
if not isinstance(num_tokens, (tuple, list)):
|
2673
|
+
num_tokens = [num_tokens]
|
2674
|
+
|
2675
|
+
if not isinstance(scale, list):
|
2676
|
+
scale = [scale] * len(num_tokens)
|
2677
|
+
if len(scale) != len(num_tokens):
|
2678
|
+
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
|
2679
|
+
self.scale = scale
|
2680
|
+
|
2681
|
+
self.to_k_ip = nn.ModuleList(
|
2682
|
+
[
|
2683
|
+
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
|
2684
|
+
for _ in range(len(num_tokens))
|
2685
|
+
]
|
2686
|
+
)
|
2687
|
+
self.to_v_ip = nn.ModuleList(
|
2688
|
+
[
|
2689
|
+
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
|
2690
|
+
for _ in range(len(num_tokens))
|
2691
|
+
]
|
2692
|
+
)
|
2693
|
+
|
1811
2694
|
def __call__(
|
1812
2695
|
self,
|
1813
2696
|
attn: Attention,
|
@@ -1815,36 +2698,34 @@ class FusedFluxAttnProcessor2_0:
|
|
1815
2698
|
encoder_hidden_states: torch.FloatTensor = None,
|
1816
2699
|
attention_mask: Optional[torch.FloatTensor] = None,
|
1817
2700
|
image_rotary_emb: Optional[torch.Tensor] = None,
|
2701
|
+
ip_hidden_states: Optional[List[torch.Tensor]] = None,
|
2702
|
+
ip_adapter_masks: Optional[torch.Tensor] = None,
|
1818
2703
|
) -> torch.FloatTensor:
|
1819
2704
|
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1820
2705
|
|
1821
2706
|
# `sample` projections.
|
1822
|
-
|
1823
|
-
|
1824
|
-
|
2707
|
+
hidden_states_query_proj = attn.to_q(hidden_states)
|
2708
|
+
key = attn.to_k(hidden_states)
|
2709
|
+
value = attn.to_v(hidden_states)
|
1825
2710
|
|
1826
2711
|
inner_dim = key.shape[-1]
|
1827
2712
|
head_dim = inner_dim // attn.heads
|
1828
2713
|
|
1829
|
-
|
2714
|
+
hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1830
2715
|
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1831
2716
|
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1832
2717
|
|
1833
2718
|
if attn.norm_q is not None:
|
1834
|
-
|
2719
|
+
hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
|
1835
2720
|
if attn.norm_k is not None:
|
1836
2721
|
key = attn.norm_k(key)
|
1837
2722
|
|
1838
2723
|
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
|
1839
|
-
# `context` projections.
|
1840
2724
|
if encoder_hidden_states is not None:
|
1841
|
-
|
1842
|
-
|
1843
|
-
(
|
1844
|
-
|
1845
|
-
encoder_hidden_states_key_proj,
|
1846
|
-
encoder_hidden_states_value_proj,
|
1847
|
-
) = torch.split(encoder_qkv, split_size, dim=-1)
|
2725
|
+
# `context` projections.
|
2726
|
+
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
2727
|
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
2728
|
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
1848
2729
|
|
1849
2730
|
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
1850
2731
|
batch_size, -1, attn.heads, head_dim
|
@@ -1862,7 +2743,7 @@ class FusedFluxAttnProcessor2_0:
|
|
1862
2743
|
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
1863
2744
|
|
1864
2745
|
# attention
|
1865
|
-
query = torch.cat([encoder_hidden_states_query_proj,
|
2746
|
+
query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
|
1866
2747
|
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
1867
2748
|
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
1868
2749
|
|
@@ -1888,7 +2769,29 @@ class FusedFluxAttnProcessor2_0:
|
|
1888
2769
|
hidden_states = attn.to_out[1](hidden_states)
|
1889
2770
|
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
1890
2771
|
|
1891
|
-
|
2772
|
+
# IP-adapter
|
2773
|
+
ip_query = hidden_states_query_proj
|
2774
|
+
ip_attn_output = None
|
2775
|
+
# for ip-adapter
|
2776
|
+
# TODO: support for multiple adapters
|
2777
|
+
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
|
2778
|
+
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
|
2779
|
+
):
|
2780
|
+
ip_key = to_k_ip(current_ip_hidden_states)
|
2781
|
+
ip_value = to_v_ip(current_ip_hidden_states)
|
2782
|
+
|
2783
|
+
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2784
|
+
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2785
|
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
2786
|
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
2787
|
+
ip_attn_output = F.scaled_dot_product_attention(
|
2788
|
+
ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
2789
|
+
)
|
2790
|
+
ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
2791
|
+
ip_attn_output = scale * ip_attn_output
|
2792
|
+
ip_attn_output = ip_attn_output.to(ip_query.dtype)
|
2793
|
+
|
2794
|
+
return hidden_states, encoder_hidden_states, ip_attn_output
|
1892
2795
|
else:
|
1893
2796
|
return hidden_states
|
1894
2797
|
|
@@ -2285,7 +3188,217 @@ class AttnProcessorNPU:
|
|
2285
3188
|
inner_precise=0,
|
2286
3189
|
)[0]
|
2287
3190
|
else:
|
2288
|
-
# TODO: add support for attn.scale when we move to Torch 2.1
|
3191
|
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
3192
|
+
hidden_states = F.scaled_dot_product_attention(
|
3193
|
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
3194
|
+
)
|
3195
|
+
|
3196
|
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
3197
|
+
hidden_states = hidden_states.to(query.dtype)
|
3198
|
+
|
3199
|
+
# linear proj
|
3200
|
+
hidden_states = attn.to_out[0](hidden_states)
|
3201
|
+
# dropout
|
3202
|
+
hidden_states = attn.to_out[1](hidden_states)
|
3203
|
+
|
3204
|
+
if input_ndim == 4:
|
3205
|
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
3206
|
+
|
3207
|
+
if attn.residual_connection:
|
3208
|
+
hidden_states = hidden_states + residual
|
3209
|
+
|
3210
|
+
hidden_states = hidden_states / attn.rescale_output_factor
|
3211
|
+
|
3212
|
+
return hidden_states
|
3213
|
+
|
3214
|
+
|
3215
|
+
class AttnProcessor2_0:
|
3216
|
+
r"""
|
3217
|
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
3218
|
+
"""
|
3219
|
+
|
3220
|
+
def __init__(self):
|
3221
|
+
if not hasattr(F, "scaled_dot_product_attention"):
|
3222
|
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
3223
|
+
|
3224
|
+
def __call__(
|
3225
|
+
self,
|
3226
|
+
attn: Attention,
|
3227
|
+
hidden_states: torch.Tensor,
|
3228
|
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
3229
|
+
attention_mask: Optional[torch.Tensor] = None,
|
3230
|
+
temb: Optional[torch.Tensor] = None,
|
3231
|
+
*args,
|
3232
|
+
**kwargs,
|
3233
|
+
) -> torch.Tensor:
|
3234
|
+
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
3235
|
+
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
3236
|
+
deprecate("scale", "1.0.0", deprecation_message)
|
3237
|
+
|
3238
|
+
residual = hidden_states
|
3239
|
+
if attn.spatial_norm is not None:
|
3240
|
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
3241
|
+
|
3242
|
+
input_ndim = hidden_states.ndim
|
3243
|
+
|
3244
|
+
if input_ndim == 4:
|
3245
|
+
batch_size, channel, height, width = hidden_states.shape
|
3246
|
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
3247
|
+
|
3248
|
+
batch_size, sequence_length, _ = (
|
3249
|
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
3250
|
+
)
|
3251
|
+
|
3252
|
+
if attention_mask is not None:
|
3253
|
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
3254
|
+
# scaled_dot_product_attention expects attention_mask shape to be
|
3255
|
+
# (batch, heads, source_length, target_length)
|
3256
|
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
3257
|
+
|
3258
|
+
if attn.group_norm is not None:
|
3259
|
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
3260
|
+
|
3261
|
+
query = attn.to_q(hidden_states)
|
3262
|
+
|
3263
|
+
if encoder_hidden_states is None:
|
3264
|
+
encoder_hidden_states = hidden_states
|
3265
|
+
elif attn.norm_cross:
|
3266
|
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
3267
|
+
|
3268
|
+
key = attn.to_k(encoder_hidden_states)
|
3269
|
+
value = attn.to_v(encoder_hidden_states)
|
3270
|
+
|
3271
|
+
inner_dim = key.shape[-1]
|
3272
|
+
head_dim = inner_dim // attn.heads
|
3273
|
+
|
3274
|
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
3275
|
+
|
3276
|
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
3277
|
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
3278
|
+
|
3279
|
+
if attn.norm_q is not None:
|
3280
|
+
query = attn.norm_q(query)
|
3281
|
+
if attn.norm_k is not None:
|
3282
|
+
key = attn.norm_k(key)
|
3283
|
+
|
3284
|
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
3285
|
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
3286
|
+
hidden_states = F.scaled_dot_product_attention(
|
3287
|
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
3288
|
+
)
|
3289
|
+
|
3290
|
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
3291
|
+
hidden_states = hidden_states.to(query.dtype)
|
3292
|
+
|
3293
|
+
# linear proj
|
3294
|
+
hidden_states = attn.to_out[0](hidden_states)
|
3295
|
+
# dropout
|
3296
|
+
hidden_states = attn.to_out[1](hidden_states)
|
3297
|
+
|
3298
|
+
if input_ndim == 4:
|
3299
|
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
3300
|
+
|
3301
|
+
if attn.residual_connection:
|
3302
|
+
hidden_states = hidden_states + residual
|
3303
|
+
|
3304
|
+
hidden_states = hidden_states / attn.rescale_output_factor
|
3305
|
+
|
3306
|
+
return hidden_states
|
3307
|
+
|
3308
|
+
|
3309
|
+
class XLAFlashAttnProcessor2_0:
|
3310
|
+
r"""
|
3311
|
+
Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
|
3312
|
+
"""
|
3313
|
+
|
3314
|
+
def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
|
3315
|
+
if not hasattr(F, "scaled_dot_product_attention"):
|
3316
|
+
raise ImportError(
|
3317
|
+
"XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
3318
|
+
)
|
3319
|
+
if is_torch_xla_version("<", "2.3"):
|
3320
|
+
raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
|
3321
|
+
if is_spmd() and is_torch_xla_version("<", "2.4"):
|
3322
|
+
raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
|
3323
|
+
self.partition_spec = partition_spec
|
3324
|
+
|
3325
|
+
def __call__(
|
3326
|
+
self,
|
3327
|
+
attn: Attention,
|
3328
|
+
hidden_states: torch.Tensor,
|
3329
|
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
3330
|
+
attention_mask: Optional[torch.Tensor] = None,
|
3331
|
+
temb: Optional[torch.Tensor] = None,
|
3332
|
+
*args,
|
3333
|
+
**kwargs,
|
3334
|
+
) -> torch.Tensor:
|
3335
|
+
residual = hidden_states
|
3336
|
+
if attn.spatial_norm is not None:
|
3337
|
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
3338
|
+
|
3339
|
+
input_ndim = hidden_states.ndim
|
3340
|
+
|
3341
|
+
if input_ndim == 4:
|
3342
|
+
batch_size, channel, height, width = hidden_states.shape
|
3343
|
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
3344
|
+
|
3345
|
+
batch_size, sequence_length, _ = (
|
3346
|
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
3347
|
+
)
|
3348
|
+
|
3349
|
+
if attention_mask is not None:
|
3350
|
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
3351
|
+
# scaled_dot_product_attention expects attention_mask shape to be
|
3352
|
+
# (batch, heads, source_length, target_length)
|
3353
|
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
3354
|
+
|
3355
|
+
if attn.group_norm is not None:
|
3356
|
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
3357
|
+
|
3358
|
+
query = attn.to_q(hidden_states)
|
3359
|
+
|
3360
|
+
if encoder_hidden_states is None:
|
3361
|
+
encoder_hidden_states = hidden_states
|
3362
|
+
elif attn.norm_cross:
|
3363
|
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
3364
|
+
|
3365
|
+
key = attn.to_k(encoder_hidden_states)
|
3366
|
+
value = attn.to_v(encoder_hidden_states)
|
3367
|
+
|
3368
|
+
inner_dim = key.shape[-1]
|
3369
|
+
head_dim = inner_dim // attn.heads
|
3370
|
+
|
3371
|
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
3372
|
+
|
3373
|
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
3374
|
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
3375
|
+
|
3376
|
+
if attn.norm_q is not None:
|
3377
|
+
query = attn.norm_q(query)
|
3378
|
+
if attn.norm_k is not None:
|
3379
|
+
key = attn.norm_k(key)
|
3380
|
+
|
3381
|
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
3382
|
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
3383
|
+
if all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
|
3384
|
+
if attention_mask is not None:
|
3385
|
+
attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
|
3386
|
+
# Convert mask to float and replace 0s with -inf and 1s with 0
|
3387
|
+
attention_mask = (
|
3388
|
+
attention_mask.float()
|
3389
|
+
.masked_fill(attention_mask == 0, float("-inf"))
|
3390
|
+
.masked_fill(attention_mask == 1, float(0.0))
|
3391
|
+
)
|
3392
|
+
|
3393
|
+
# Apply attention mask to key
|
3394
|
+
key = key + attention_mask
|
3395
|
+
query /= math.sqrt(query.shape[3])
|
3396
|
+
partition_spec = self.partition_spec if is_spmd() else None
|
3397
|
+
hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
|
3398
|
+
else:
|
3399
|
+
logger.warning(
|
3400
|
+
"Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096."
|
3401
|
+
)
|
2289
3402
|
hidden_states = F.scaled_dot_product_attention(
|
2290
3403
|
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
2291
3404
|
)
|
@@ -2309,9 +3422,9 @@ class AttnProcessorNPU:
|
|
2309
3422
|
return hidden_states
|
2310
3423
|
|
2311
3424
|
|
2312
|
-
class
|
3425
|
+
class MochiVaeAttnProcessor2_0:
|
2313
3426
|
r"""
|
2314
|
-
|
3427
|
+
Attention processor used in Mochi VAE.
|
2315
3428
|
"""
|
2316
3429
|
|
2317
3430
|
def __init__(self):
|
@@ -2324,23 +3437,9 @@ class AttnProcessor2_0:
|
|
2324
3437
|
hidden_states: torch.Tensor,
|
2325
3438
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
2326
3439
|
attention_mask: Optional[torch.Tensor] = None,
|
2327
|
-
temb: Optional[torch.Tensor] = None,
|
2328
|
-
*args,
|
2329
|
-
**kwargs,
|
2330
3440
|
) -> torch.Tensor:
|
2331
|
-
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
2332
|
-
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
2333
|
-
deprecate("scale", "1.0.0", deprecation_message)
|
2334
|
-
|
2335
3441
|
residual = hidden_states
|
2336
|
-
|
2337
|
-
hidden_states = attn.spatial_norm(hidden_states, temb)
|
2338
|
-
|
2339
|
-
input_ndim = hidden_states.ndim
|
2340
|
-
|
2341
|
-
if input_ndim == 4:
|
2342
|
-
batch_size, channel, height, width = hidden_states.shape
|
2343
|
-
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
3442
|
+
is_single_frame = hidden_states.shape[1] == 1
|
2344
3443
|
|
2345
3444
|
batch_size, sequence_length, _ = (
|
2346
3445
|
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
@@ -2352,15 +3451,24 @@ class AttnProcessor2_0:
|
|
2352
3451
|
# (batch, heads, source_length, target_length)
|
2353
3452
|
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
2354
3453
|
|
2355
|
-
if
|
2356
|
-
hidden_states = attn.
|
3454
|
+
if is_single_frame:
|
3455
|
+
hidden_states = attn.to_v(hidden_states)
|
3456
|
+
|
3457
|
+
# linear proj
|
3458
|
+
hidden_states = attn.to_out[0](hidden_states)
|
3459
|
+
# dropout
|
3460
|
+
hidden_states = attn.to_out[1](hidden_states)
|
3461
|
+
|
3462
|
+
if attn.residual_connection:
|
3463
|
+
hidden_states = hidden_states + residual
|
3464
|
+
|
3465
|
+
hidden_states = hidden_states / attn.rescale_output_factor
|
3466
|
+
return hidden_states
|
2357
3467
|
|
2358
3468
|
query = attn.to_q(hidden_states)
|
2359
3469
|
|
2360
3470
|
if encoder_hidden_states is None:
|
2361
3471
|
encoder_hidden_states = hidden_states
|
2362
|
-
elif attn.norm_cross:
|
2363
|
-
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
2364
3472
|
|
2365
3473
|
key = attn.to_k(encoder_hidden_states)
|
2366
3474
|
value = attn.to_v(encoder_hidden_states)
|
@@ -2369,7 +3477,6 @@ class AttnProcessor2_0:
|
|
2369
3477
|
head_dim = inner_dim // attn.heads
|
2370
3478
|
|
2371
3479
|
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2372
|
-
|
2373
3480
|
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2374
3481
|
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
2375
3482
|
|
@@ -2381,7 +3488,7 @@ class AttnProcessor2_0:
|
|
2381
3488
|
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
2382
3489
|
# TODO: add support for attn.scale when we move to Torch 2.1
|
2383
3490
|
hidden_states = F.scaled_dot_product_attention(
|
2384
|
-
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=
|
3491
|
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=attn.is_causal
|
2385
3492
|
)
|
2386
3493
|
|
2387
3494
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
@@ -2392,9 +3499,6 @@ class AttnProcessor2_0:
|
|
2392
3499
|
# dropout
|
2393
3500
|
hidden_states = attn.to_out[1](hidden_states)
|
2394
3501
|
|
2395
|
-
if input_ndim == 4:
|
2396
|
-
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
2397
|
-
|
2398
3502
|
if attn.residual_connection:
|
2399
3503
|
hidden_states = hidden_states + residual
|
2400
3504
|
|
@@ -3597,34 +4701,232 @@ class SpatialNorm(nn.Module):
|
|
3597
4701
|
"""
|
3598
4702
|
Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.
|
3599
4703
|
|
3600
|
-
Args:
|
3601
|
-
f_channels (`int`):
|
3602
|
-
The number of channels for input to group normalization layer, and output of the spatial norm layer.
|
3603
|
-
zq_channels (`int`):
|
3604
|
-
The number of channels for the quantized vector as described in the paper.
|
3605
|
-
"""
|
4704
|
+
Args:
|
4705
|
+
f_channels (`int`):
|
4706
|
+
The number of channels for input to group normalization layer, and output of the spatial norm layer.
|
4707
|
+
zq_channels (`int`):
|
4708
|
+
The number of channels for the quantized vector as described in the paper.
|
4709
|
+
"""
|
4710
|
+
|
4711
|
+
def __init__(
|
4712
|
+
self,
|
4713
|
+
f_channels: int,
|
4714
|
+
zq_channels: int,
|
4715
|
+
):
|
4716
|
+
super().__init__()
|
4717
|
+
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
|
4718
|
+
self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
|
4719
|
+
self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
|
4720
|
+
|
4721
|
+
def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
|
4722
|
+
f_size = f.shape[-2:]
|
4723
|
+
zq = F.interpolate(zq, size=f_size, mode="nearest")
|
4724
|
+
norm_f = self.norm_layer(f)
|
4725
|
+
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
4726
|
+
return new_f
|
4727
|
+
|
4728
|
+
|
4729
|
+
class IPAdapterAttnProcessor(nn.Module):
|
4730
|
+
r"""
|
4731
|
+
Attention processor for Multiple IP-Adapters.
|
4732
|
+
|
4733
|
+
Args:
|
4734
|
+
hidden_size (`int`):
|
4735
|
+
The hidden size of the attention layer.
|
4736
|
+
cross_attention_dim (`int`):
|
4737
|
+
The number of channels in the `encoder_hidden_states`.
|
4738
|
+
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
|
4739
|
+
The context length of the image features.
|
4740
|
+
scale (`float` or List[`float`], defaults to 1.0):
|
4741
|
+
the weight scale of image prompt.
|
4742
|
+
"""
|
4743
|
+
|
4744
|
+
def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
|
4745
|
+
super().__init__()
|
4746
|
+
|
4747
|
+
self.hidden_size = hidden_size
|
4748
|
+
self.cross_attention_dim = cross_attention_dim
|
4749
|
+
|
4750
|
+
if not isinstance(num_tokens, (tuple, list)):
|
4751
|
+
num_tokens = [num_tokens]
|
4752
|
+
self.num_tokens = num_tokens
|
4753
|
+
|
4754
|
+
if not isinstance(scale, list):
|
4755
|
+
scale = [scale] * len(num_tokens)
|
4756
|
+
if len(scale) != len(num_tokens):
|
4757
|
+
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
|
4758
|
+
self.scale = scale
|
4759
|
+
|
4760
|
+
self.to_k_ip = nn.ModuleList(
|
4761
|
+
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
|
4762
|
+
)
|
4763
|
+
self.to_v_ip = nn.ModuleList(
|
4764
|
+
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
|
4765
|
+
)
|
4766
|
+
|
4767
|
+
def __call__(
|
4768
|
+
self,
|
4769
|
+
attn: Attention,
|
4770
|
+
hidden_states: torch.Tensor,
|
4771
|
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
4772
|
+
attention_mask: Optional[torch.Tensor] = None,
|
4773
|
+
temb: Optional[torch.Tensor] = None,
|
4774
|
+
scale: float = 1.0,
|
4775
|
+
ip_adapter_masks: Optional[torch.Tensor] = None,
|
4776
|
+
):
|
4777
|
+
residual = hidden_states
|
4778
|
+
|
4779
|
+
# separate ip_hidden_states from encoder_hidden_states
|
4780
|
+
if encoder_hidden_states is not None:
|
4781
|
+
if isinstance(encoder_hidden_states, tuple):
|
4782
|
+
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
|
4783
|
+
else:
|
4784
|
+
deprecation_message = (
|
4785
|
+
"You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
|
4786
|
+
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
|
4787
|
+
)
|
4788
|
+
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
|
4789
|
+
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
|
4790
|
+
encoder_hidden_states, ip_hidden_states = (
|
4791
|
+
encoder_hidden_states[:, :end_pos, :],
|
4792
|
+
[encoder_hidden_states[:, end_pos:, :]],
|
4793
|
+
)
|
4794
|
+
|
4795
|
+
if attn.spatial_norm is not None:
|
4796
|
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
4797
|
+
|
4798
|
+
input_ndim = hidden_states.ndim
|
4799
|
+
|
4800
|
+
if input_ndim == 4:
|
4801
|
+
batch_size, channel, height, width = hidden_states.shape
|
4802
|
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
4803
|
+
|
4804
|
+
batch_size, sequence_length, _ = (
|
4805
|
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
4806
|
+
)
|
4807
|
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
4808
|
+
|
4809
|
+
if attn.group_norm is not None:
|
4810
|
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
4811
|
+
|
4812
|
+
query = attn.to_q(hidden_states)
|
4813
|
+
|
4814
|
+
if encoder_hidden_states is None:
|
4815
|
+
encoder_hidden_states = hidden_states
|
4816
|
+
elif attn.norm_cross:
|
4817
|
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
4818
|
+
|
4819
|
+
key = attn.to_k(encoder_hidden_states)
|
4820
|
+
value = attn.to_v(encoder_hidden_states)
|
4821
|
+
|
4822
|
+
query = attn.head_to_batch_dim(query)
|
4823
|
+
key = attn.head_to_batch_dim(key)
|
4824
|
+
value = attn.head_to_batch_dim(value)
|
4825
|
+
|
4826
|
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
4827
|
+
hidden_states = torch.bmm(attention_probs, value)
|
4828
|
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
4829
|
+
|
4830
|
+
if ip_adapter_masks is not None:
|
4831
|
+
if not isinstance(ip_adapter_masks, List):
|
4832
|
+
# for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
|
4833
|
+
ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
|
4834
|
+
if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
|
4835
|
+
raise ValueError(
|
4836
|
+
f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
|
4837
|
+
f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
|
4838
|
+
f"({len(ip_hidden_states)})"
|
4839
|
+
)
|
4840
|
+
else:
|
4841
|
+
for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
|
4842
|
+
if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
|
4843
|
+
raise ValueError(
|
4844
|
+
"Each element of the ip_adapter_masks array should be a tensor with shape "
|
4845
|
+
"[1, num_images_for_ip_adapter, height, width]."
|
4846
|
+
" Please use `IPAdapterMaskProcessor` to preprocess your mask"
|
4847
|
+
)
|
4848
|
+
if mask.shape[1] != ip_state.shape[1]:
|
4849
|
+
raise ValueError(
|
4850
|
+
f"Number of masks ({mask.shape[1]}) does not match "
|
4851
|
+
f"number of ip images ({ip_state.shape[1]}) at index {index}"
|
4852
|
+
)
|
4853
|
+
if isinstance(scale, list) and not len(scale) == mask.shape[1]:
|
4854
|
+
raise ValueError(
|
4855
|
+
f"Number of masks ({mask.shape[1]}) does not match "
|
4856
|
+
f"number of scales ({len(scale)}) at index {index}"
|
4857
|
+
)
|
4858
|
+
else:
|
4859
|
+
ip_adapter_masks = [None] * len(self.scale)
|
4860
|
+
|
4861
|
+
# for ip-adapter
|
4862
|
+
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
|
4863
|
+
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
|
4864
|
+
):
|
4865
|
+
skip = False
|
4866
|
+
if isinstance(scale, list):
|
4867
|
+
if all(s == 0 for s in scale):
|
4868
|
+
skip = True
|
4869
|
+
elif scale == 0:
|
4870
|
+
skip = True
|
4871
|
+
if not skip:
|
4872
|
+
if mask is not None:
|
4873
|
+
if not isinstance(scale, list):
|
4874
|
+
scale = [scale] * mask.shape[1]
|
4875
|
+
|
4876
|
+
current_num_images = mask.shape[1]
|
4877
|
+
for i in range(current_num_images):
|
4878
|
+
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
|
4879
|
+
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
|
4880
|
+
|
4881
|
+
ip_key = attn.head_to_batch_dim(ip_key)
|
4882
|
+
ip_value = attn.head_to_batch_dim(ip_value)
|
4883
|
+
|
4884
|
+
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
4885
|
+
_current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
4886
|
+
_current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
|
4887
|
+
|
4888
|
+
mask_downsample = IPAdapterMaskProcessor.downsample(
|
4889
|
+
mask[:, i, :, :],
|
4890
|
+
batch_size,
|
4891
|
+
_current_ip_hidden_states.shape[1],
|
4892
|
+
_current_ip_hidden_states.shape[2],
|
4893
|
+
)
|
4894
|
+
|
4895
|
+
mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
|
4896
|
+
|
4897
|
+
hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
|
4898
|
+
else:
|
4899
|
+
ip_key = to_k_ip(current_ip_hidden_states)
|
4900
|
+
ip_value = to_v_ip(current_ip_hidden_states)
|
4901
|
+
|
4902
|
+
ip_key = attn.head_to_batch_dim(ip_key)
|
4903
|
+
ip_value = attn.head_to_batch_dim(ip_value)
|
4904
|
+
|
4905
|
+
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
4906
|
+
current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
4907
|
+
current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
|
4908
|
+
|
4909
|
+
hidden_states = hidden_states + scale * current_ip_hidden_states
|
4910
|
+
|
4911
|
+
# linear proj
|
4912
|
+
hidden_states = attn.to_out[0](hidden_states)
|
4913
|
+
# dropout
|
4914
|
+
hidden_states = attn.to_out[1](hidden_states)
|
4915
|
+
|
4916
|
+
if input_ndim == 4:
|
4917
|
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
4918
|
+
|
4919
|
+
if attn.residual_connection:
|
4920
|
+
hidden_states = hidden_states + residual
|
3606
4921
|
|
3607
|
-
|
3608
|
-
self,
|
3609
|
-
f_channels: int,
|
3610
|
-
zq_channels: int,
|
3611
|
-
):
|
3612
|
-
super().__init__()
|
3613
|
-
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
|
3614
|
-
self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
|
3615
|
-
self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
|
4922
|
+
hidden_states = hidden_states / attn.rescale_output_factor
|
3616
4923
|
|
3617
|
-
|
3618
|
-
f_size = f.shape[-2:]
|
3619
|
-
zq = F.interpolate(zq, size=f_size, mode="nearest")
|
3620
|
-
norm_f = self.norm_layer(f)
|
3621
|
-
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
3622
|
-
return new_f
|
4924
|
+
return hidden_states
|
3623
4925
|
|
3624
4926
|
|
3625
|
-
class
|
4927
|
+
class IPAdapterAttnProcessor2_0(torch.nn.Module):
|
3626
4928
|
r"""
|
3627
|
-
Attention processor for
|
4929
|
+
Attention processor for IP-Adapter for PyTorch 2.0.
|
3628
4930
|
|
3629
4931
|
Args:
|
3630
4932
|
hidden_size (`int`):
|
@@ -3633,13 +4935,18 @@ class IPAdapterAttnProcessor(nn.Module):
|
|
3633
4935
|
The number of channels in the `encoder_hidden_states`.
|
3634
4936
|
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
|
3635
4937
|
The context length of the image features.
|
3636
|
-
scale (`float` or List[
|
4938
|
+
scale (`float` or `List[float]`, defaults to 1.0):
|
3637
4939
|
the weight scale of image prompt.
|
3638
4940
|
"""
|
3639
4941
|
|
3640
4942
|
def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
|
3641
4943
|
super().__init__()
|
3642
4944
|
|
4945
|
+
if not hasattr(F, "scaled_dot_product_attention"):
|
4946
|
+
raise ImportError(
|
4947
|
+
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
4948
|
+
)
|
4949
|
+
|
3643
4950
|
self.hidden_size = hidden_size
|
3644
4951
|
self.cross_attention_dim = cross_attention_dim
|
3645
4952
|
|
@@ -3700,7 +5007,12 @@ class IPAdapterAttnProcessor(nn.Module):
|
|
3700
5007
|
batch_size, sequence_length, _ = (
|
3701
5008
|
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
3702
5009
|
)
|
3703
|
-
|
5010
|
+
|
5011
|
+
if attention_mask is not None:
|
5012
|
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
5013
|
+
# scaled_dot_product_attention expects attention_mask shape to be
|
5014
|
+
# (batch, heads, source_length, target_length)
|
5015
|
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
3704
5016
|
|
3705
5017
|
if attn.group_norm is not None:
|
3706
5018
|
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
@@ -3715,13 +5027,22 @@ class IPAdapterAttnProcessor(nn.Module):
|
|
3715
5027
|
key = attn.to_k(encoder_hidden_states)
|
3716
5028
|
value = attn.to_v(encoder_hidden_states)
|
3717
5029
|
|
3718
|
-
|
3719
|
-
|
3720
|
-
value = attn.head_to_batch_dim(value)
|
5030
|
+
inner_dim = key.shape[-1]
|
5031
|
+
head_dim = inner_dim // attn.heads
|
3721
5032
|
|
3722
|
-
|
3723
|
-
|
3724
|
-
|
5033
|
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
5034
|
+
|
5035
|
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
5036
|
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
5037
|
+
|
5038
|
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
5039
|
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
5040
|
+
hidden_states = F.scaled_dot_product_attention(
|
5041
|
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
5042
|
+
)
|
5043
|
+
|
5044
|
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
5045
|
+
hidden_states = hidden_states.to(query.dtype)
|
3725
5046
|
|
3726
5047
|
if ip_adapter_masks is not None:
|
3727
5048
|
if not isinstance(ip_adapter_masks, List):
|
@@ -3774,12 +5095,19 @@ class IPAdapterAttnProcessor(nn.Module):
|
|
3774
5095
|
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
|
3775
5096
|
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
|
3776
5097
|
|
3777
|
-
ip_key = attn.
|
3778
|
-
ip_value = attn.
|
5098
|
+
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
5099
|
+
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
3779
5100
|
|
3780
|
-
|
3781
|
-
|
3782
|
-
_current_ip_hidden_states =
|
5101
|
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
5102
|
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
5103
|
+
_current_ip_hidden_states = F.scaled_dot_product_attention(
|
5104
|
+
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
5105
|
+
)
|
5106
|
+
|
5107
|
+
_current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
|
5108
|
+
batch_size, -1, attn.heads * head_dim
|
5109
|
+
)
|
5110
|
+
_current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
|
3783
5111
|
|
3784
5112
|
mask_downsample = IPAdapterMaskProcessor.downsample(
|
3785
5113
|
mask[:, i, :, :],
|
@@ -3789,18 +5117,24 @@ class IPAdapterAttnProcessor(nn.Module):
|
|
3789
5117
|
)
|
3790
5118
|
|
3791
5119
|
mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
|
3792
|
-
|
3793
5120
|
hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
|
3794
5121
|
else:
|
3795
5122
|
ip_key = to_k_ip(current_ip_hidden_states)
|
3796
5123
|
ip_value = to_v_ip(current_ip_hidden_states)
|
3797
5124
|
|
3798
|
-
ip_key = attn.
|
3799
|
-
ip_value = attn.
|
5125
|
+
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
5126
|
+
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
3800
5127
|
|
3801
|
-
|
3802
|
-
|
3803
|
-
current_ip_hidden_states =
|
5128
|
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
5129
|
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
5130
|
+
current_ip_hidden_states = F.scaled_dot_product_attention(
|
5131
|
+
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
5132
|
+
)
|
5133
|
+
|
5134
|
+
current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
|
5135
|
+
batch_size, -1, attn.heads * head_dim
|
5136
|
+
)
|
5137
|
+
current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
|
3804
5138
|
|
3805
5139
|
hidden_states = hidden_states + scale * current_ip_hidden_states
|
3806
5140
|
|
@@ -3820,9 +5154,9 @@ class IPAdapterAttnProcessor(nn.Module):
|
|
3820
5154
|
return hidden_states
|
3821
5155
|
|
3822
5156
|
|
3823
|
-
class
|
5157
|
+
class IPAdapterXFormersAttnProcessor(torch.nn.Module):
|
3824
5158
|
r"""
|
3825
|
-
Attention processor for IP-Adapter
|
5159
|
+
Attention processor for IP-Adapter using xFormers.
|
3826
5160
|
|
3827
5161
|
Args:
|
3828
5162
|
hidden_size (`int`):
|
@@ -3833,18 +5167,26 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
|
|
3833
5167
|
The context length of the image features.
|
3834
5168
|
scale (`float` or `List[float]`, defaults to 1.0):
|
3835
5169
|
the weight scale of image prompt.
|
5170
|
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
5171
|
+
The base
|
5172
|
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
5173
|
+
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
5174
|
+
operator.
|
3836
5175
|
"""
|
3837
5176
|
|
3838
|
-
def __init__(
|
5177
|
+
def __init__(
|
5178
|
+
self,
|
5179
|
+
hidden_size,
|
5180
|
+
cross_attention_dim=None,
|
5181
|
+
num_tokens=(4,),
|
5182
|
+
scale=1.0,
|
5183
|
+
attention_op: Optional[Callable] = None,
|
5184
|
+
):
|
3839
5185
|
super().__init__()
|
3840
5186
|
|
3841
|
-
if not hasattr(F, "scaled_dot_product_attention"):
|
3842
|
-
raise ImportError(
|
3843
|
-
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
3844
|
-
)
|
3845
|
-
|
3846
5187
|
self.hidden_size = hidden_size
|
3847
5188
|
self.cross_attention_dim = cross_attention_dim
|
5189
|
+
self.attention_op = attention_op
|
3848
5190
|
|
3849
5191
|
if not isinstance(num_tokens, (tuple, list)):
|
3850
5192
|
num_tokens = [num_tokens]
|
@@ -3857,21 +5199,21 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
|
|
3857
5199
|
self.scale = scale
|
3858
5200
|
|
3859
5201
|
self.to_k_ip = nn.ModuleList(
|
3860
|
-
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
|
5202
|
+
[nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
|
3861
5203
|
)
|
3862
5204
|
self.to_v_ip = nn.ModuleList(
|
3863
|
-
[nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
|
5205
|
+
[nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
|
3864
5206
|
)
|
3865
5207
|
|
3866
5208
|
def __call__(
|
3867
5209
|
self,
|
3868
5210
|
attn: Attention,
|
3869
|
-
hidden_states: torch.
|
3870
|
-
encoder_hidden_states: Optional[torch.
|
3871
|
-
attention_mask: Optional[torch.
|
3872
|
-
temb: Optional[torch.
|
5211
|
+
hidden_states: torch.FloatTensor,
|
5212
|
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
5213
|
+
attention_mask: Optional[torch.FloatTensor] = None,
|
5214
|
+
temb: Optional[torch.FloatTensor] = None,
|
3873
5215
|
scale: float = 1.0,
|
3874
|
-
ip_adapter_masks: Optional[torch.
|
5216
|
+
ip_adapter_masks: Optional[torch.FloatTensor] = None,
|
3875
5217
|
):
|
3876
5218
|
residual = hidden_states
|
3877
5219
|
|
@@ -3906,9 +5248,14 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
|
|
3906
5248
|
|
3907
5249
|
if attention_mask is not None:
|
3908
5250
|
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
3909
|
-
#
|
3910
|
-
#
|
3911
|
-
|
5251
|
+
# expand our mask's singleton query_tokens dimension:
|
5252
|
+
# [batch*heads, 1, key_tokens] ->
|
5253
|
+
# [batch*heads, query_tokens, key_tokens]
|
5254
|
+
# so that it can be added as a bias onto the attention scores that xformers computes:
|
5255
|
+
# [batch*heads, query_tokens, key_tokens]
|
5256
|
+
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
|
5257
|
+
_, query_tokens, _ = hidden_states.shape
|
5258
|
+
attention_mask = attention_mask.expand(-1, query_tokens, -1)
|
3912
5259
|
|
3913
5260
|
if attn.group_norm is not None:
|
3914
5261
|
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
@@ -3923,131 +5270,291 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
|
|
3923
5270
|
key = attn.to_k(encoder_hidden_states)
|
3924
5271
|
value = attn.to_v(encoder_hidden_states)
|
3925
5272
|
|
5273
|
+
query = attn.head_to_batch_dim(query).contiguous()
|
5274
|
+
key = attn.head_to_batch_dim(key).contiguous()
|
5275
|
+
value = attn.head_to_batch_dim(value).contiguous()
|
5276
|
+
|
5277
|
+
hidden_states = xformers.ops.memory_efficient_attention(
|
5278
|
+
query, key, value, attn_bias=attention_mask, op=self.attention_op
|
5279
|
+
)
|
5280
|
+
hidden_states = hidden_states.to(query.dtype)
|
5281
|
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
5282
|
+
|
5283
|
+
if ip_hidden_states:
|
5284
|
+
if ip_adapter_masks is not None:
|
5285
|
+
if not isinstance(ip_adapter_masks, List):
|
5286
|
+
# for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
|
5287
|
+
ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
|
5288
|
+
if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
|
5289
|
+
raise ValueError(
|
5290
|
+
f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
|
5291
|
+
f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
|
5292
|
+
f"({len(ip_hidden_states)})"
|
5293
|
+
)
|
5294
|
+
else:
|
5295
|
+
for index, (mask, scale, ip_state) in enumerate(
|
5296
|
+
zip(ip_adapter_masks, self.scale, ip_hidden_states)
|
5297
|
+
):
|
5298
|
+
if mask is None:
|
5299
|
+
continue
|
5300
|
+
if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
|
5301
|
+
raise ValueError(
|
5302
|
+
"Each element of the ip_adapter_masks array should be a tensor with shape "
|
5303
|
+
"[1, num_images_for_ip_adapter, height, width]."
|
5304
|
+
" Please use `IPAdapterMaskProcessor` to preprocess your mask"
|
5305
|
+
)
|
5306
|
+
if mask.shape[1] != ip_state.shape[1]:
|
5307
|
+
raise ValueError(
|
5308
|
+
f"Number of masks ({mask.shape[1]}) does not match "
|
5309
|
+
f"number of ip images ({ip_state.shape[1]}) at index {index}"
|
5310
|
+
)
|
5311
|
+
if isinstance(scale, list) and not len(scale) == mask.shape[1]:
|
5312
|
+
raise ValueError(
|
5313
|
+
f"Number of masks ({mask.shape[1]}) does not match "
|
5314
|
+
f"number of scales ({len(scale)}) at index {index}"
|
5315
|
+
)
|
5316
|
+
else:
|
5317
|
+
ip_adapter_masks = [None] * len(self.scale)
|
5318
|
+
|
5319
|
+
# for ip-adapter
|
5320
|
+
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
|
5321
|
+
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
|
5322
|
+
):
|
5323
|
+
skip = False
|
5324
|
+
if isinstance(scale, list):
|
5325
|
+
if all(s == 0 for s in scale):
|
5326
|
+
skip = True
|
5327
|
+
elif scale == 0:
|
5328
|
+
skip = True
|
5329
|
+
if not skip:
|
5330
|
+
if mask is not None:
|
5331
|
+
mask = mask.to(torch.float16)
|
5332
|
+
if not isinstance(scale, list):
|
5333
|
+
scale = [scale] * mask.shape[1]
|
5334
|
+
|
5335
|
+
current_num_images = mask.shape[1]
|
5336
|
+
for i in range(current_num_images):
|
5337
|
+
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
|
5338
|
+
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
|
5339
|
+
|
5340
|
+
ip_key = attn.head_to_batch_dim(ip_key).contiguous()
|
5341
|
+
ip_value = attn.head_to_batch_dim(ip_value).contiguous()
|
5342
|
+
|
5343
|
+
_current_ip_hidden_states = xformers.ops.memory_efficient_attention(
|
5344
|
+
query, ip_key, ip_value, op=self.attention_op
|
5345
|
+
)
|
5346
|
+
_current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
|
5347
|
+
_current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
|
5348
|
+
|
5349
|
+
mask_downsample = IPAdapterMaskProcessor.downsample(
|
5350
|
+
mask[:, i, :, :],
|
5351
|
+
batch_size,
|
5352
|
+
_current_ip_hidden_states.shape[1],
|
5353
|
+
_current_ip_hidden_states.shape[2],
|
5354
|
+
)
|
5355
|
+
|
5356
|
+
mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
|
5357
|
+
hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
|
5358
|
+
else:
|
5359
|
+
ip_key = to_k_ip(current_ip_hidden_states)
|
5360
|
+
ip_value = to_v_ip(current_ip_hidden_states)
|
5361
|
+
+                        ip_key = attn.head_to_batch_dim(ip_key).contiguous()
+                        ip_value = attn.head_to_batch_dim(ip_value).contiguous()
+
+                        current_ip_hidden_states = xformers.ops.memory_efficient_attention(
+                            query, ip_key, ip_value, op=self.attention_op
+                        )
+                        current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+                        current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
+
+                        hidden_states = hidden_states + scale * current_ip_hidden_states
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class SD3IPAdapterJointAttnProcessor2_0(torch.nn.Module):
+    """
+    Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections, with
+    additional image-based information and timestep embeddings.
+
+    Args:
+        hidden_size (`int`):
+            The number of hidden channels.
+        ip_hidden_states_dim (`int`):
+            The image feature dimension.
+        head_dim (`int`):
+            The number of head channels.
+        timesteps_emb_dim (`int`, defaults to 1280):
+            The number of input channels for timestep embedding.
+        scale (`float`, defaults to 0.5):
+            IP-Adapter scale.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        ip_hidden_states_dim: int,
+        head_dim: int,
+        timesteps_emb_dim: int = 1280,
+        scale: float = 0.5,
+    ):
+        super().__init__()
+
+        # To prevent circular import
+        from .normalization import AdaLayerNorm, RMSNorm
+
+        self.norm_ip = AdaLayerNorm(timesteps_emb_dim, output_dim=ip_hidden_states_dim * 2, norm_eps=1e-6, chunk_dim=1)
+        self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
+        self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
+        self.norm_q = RMSNorm(head_dim, 1e-6)
+        self.norm_k = RMSNorm(head_dim, 1e-6)
+        self.norm_ip_k = RMSNorm(head_dim, 1e-6)
+        self.scale = scale
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        ip_hidden_states: torch.FloatTensor = None,
+        temb: torch.FloatTensor = None,
+    ) -> torch.FloatTensor:
+        """
+        Perform the attention computation, integrating image features (if provided) and timestep embeddings.
+
+        If `ip_hidden_states` is `None`, this is equivalent to using JointAttnProcessor2_0.
+
+        Args:
+            attn (`Attention`):
+                Attention instance.
+            hidden_states (`torch.FloatTensor`):
+                Input `hidden_states`.
+            encoder_hidden_states (`torch.FloatTensor`, *optional*):
+                The encoder hidden states.
+            attention_mask (`torch.FloatTensor`, *optional*):
+                Attention mask.
+            ip_hidden_states (`torch.FloatTensor`, *optional*):
+                Image embeddings.
+            temb (`torch.FloatTensor`, *optional*):
+                Timestep embeddings.
+
+        Returns:
+            `torch.FloatTensor`: Output hidden states.
+        """
+        residual = hidden_states
+
+        batch_size = hidden_states.shape[0]
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
 
         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        img_query = query
+        img_key = key
+        img_value = value
 
-
-
-
-
-        )
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
 
-
-
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
 
-
-
-
-
-
-
-
-
-
-                )
-            else:
-                for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
-                    if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
-                        raise ValueError(
-                            "Each element of the ip_adapter_masks array should be a tensor with shape "
-                            "[1, num_images_for_ip_adapter, height, width]."
-                            " Please use `IPAdapterMaskProcessor` to preprocess your mask"
-                        )
-                    if mask.shape[1] != ip_state.shape[1]:
-                        raise ValueError(
-                            f"Number of masks ({mask.shape[1]}) does not match "
-                            f"number of ip images ({ip_state.shape[1]}) at index {index}"
-                        )
-                    if isinstance(scale, list) and not len(scale) == mask.shape[1]:
-                        raise ValueError(
-                            f"Number of masks ({mask.shape[1]}) does not match "
-                            f"number of scales ({len(scale)}) at index {index}"
-                        )
-        else:
-            ip_adapter_masks = [None] * len(self.scale)
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
 
-
-
-
-
-            skip = False
-            if isinstance(scale, list):
-                if all(s == 0 for s in scale):
-                    skip = True
-            elif scale == 0:
-                skip = True
-            if not skip:
-                if mask is not None:
-                    if not isinstance(scale, list):
-                        scale = [scale] * mask.shape[1]
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
 
-
-
-
-                        ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+            query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
+            key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
+            value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
 
-
-
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
 
-
-
-
-
-
+        if encoder_hidden_states is not None:
+            # Split the attention outputs.
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, : residual.shape[1]],
+                hidden_states[:, residual.shape[1] :],
+            )
+            if not attn.context_pre_only:
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
 
-
-
-
-
+        # IP Adapter
+        if self.scale != 0 and ip_hidden_states is not None:
+            # Norm image features
+            norm_ip_hidden_states = self.norm_ip(ip_hidden_states, temb=temb)
 
-
-
-
-                            _current_ip_hidden_states.shape[1],
-                            _current_ip_hidden_states.shape[2],
-                        )
+            # To k and v
+            ip_key = self.to_k_ip(norm_ip_hidden_states)
+            ip_value = self.to_v_ip(norm_ip_hidden_states)
 
-
-
-
-                    ip_key = to_k_ip(current_ip_hidden_states)
-                    ip_value = to_v_ip(current_ip_hidden_states)
+            # Reshape
+            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
-
-
+            # Norm
+            query = self.norm_q(img_query)
+            img_key = self.norm_k(img_key)
+            ip_key = self.norm_ip_k(ip_key)
 
-
-
-
-                        query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-                    )
+            # cat img
+            key = torch.cat([img_key, ip_key], dim=2)
+            value = torch.cat([img_value, ip_value], dim=2)
 
-
-
-
-                    current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+            ip_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+            ip_hidden_states = ip_hidden_states.transpose(1, 2).view(batch_size, -1, attn.heads * head_dim)
+            ip_hidden_states = ip_hidden_states.to(query.dtype)
 
-
+            hidden_states = hidden_states + ip_hidden_states * self.scale
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
 
-        if
-
-
-
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
+        if encoder_hidden_states is not None:
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
 
 
 class PAGIdentitySelfAttnProcessor2_0:
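
The constructor arguments documented above map directly onto the processor's learned projections (`to_k_ip`, `to_v_ip`, the RMSNorms, and the AdaLayerNorm over timestep embeddings). A minimal construction sketch, assuming diffusers 0.32.0 is installed; the sizes below are illustrative assumptions, not the values used by any particular SD3 checkpoint:

```python
# Hedged sketch: instantiating the new SD3 IP-Adapter processor on its own.
# All sizes are illustrative assumptions, not checkpoint values.
from diffusers.models.attention_processor import SD3IPAdapterJointAttnProcessor2_0

processor = SD3IPAdapterJointAttnProcessor2_0(
    hidden_size=1536,           # attention hidden channels
    ip_hidden_states_dim=2048,  # image-embedding feature dimension
    head_dim=64,                # per-head channels (normalized with RMSNorm)
    timesteps_emb_dim=1280,     # timestep-embedding channels (default)
    scale=0.6,                  # IP-Adapter strength; 0 disables the image branch
)
print(processor.to_k_ip)        # Linear(in_features=2048, out_features=1536, bias=False)
```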
@@ -4252,22 +5759,98 @@ class PAGCFGIdentitySelfAttnProcessor2_0:
         return hidden_states
 
 
+class SanaMultiscaleAttnProcessor2_0:
+    r"""
+    Processor for implementing multiscale quadratic attention.
+    """
+
+    def __call__(self, attn: SanaMultiscaleLinearAttention, hidden_states: torch.Tensor) -> torch.Tensor:
+        height, width = hidden_states.shape[-2:]
+        if height * width > attn.attention_head_dim:
+            use_linear_attention = True
+        else:
+            use_linear_attention = False
+
+        residual = hidden_states
+
+        batch_size, _, height, width = list(hidden_states.size())
+        original_dtype = hidden_states.dtype
+
+        hidden_states = hidden_states.movedim(1, -1)
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+        hidden_states = torch.cat([query, key, value], dim=3)
+        hidden_states = hidden_states.movedim(-1, 1)
+
+        multi_scale_qkv = [hidden_states]
+        for block in attn.to_qkv_multiscale:
+            multi_scale_qkv.append(block(hidden_states))
+
+        hidden_states = torch.cat(multi_scale_qkv, dim=1)
+
+        if use_linear_attention:
+            # for linear attention upcast hidden_states to float32
+            hidden_states = hidden_states.to(dtype=torch.float32)
+
+        hidden_states = hidden_states.reshape(batch_size, -1, 3 * attn.attention_head_dim, height * width)
+
+        query, key, value = hidden_states.chunk(3, dim=2)
+        query = attn.nonlinearity(query)
+        key = attn.nonlinearity(key)
+
+        if use_linear_attention:
+            hidden_states = attn.apply_linear_attention(query, key, value)
+            hidden_states = hidden_states.to(dtype=original_dtype)
+        else:
+            hidden_states = attn.apply_quadratic_attention(query, key, value)
+
+        hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
+        hidden_states = attn.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+
+        if attn.norm_type == "rms_norm":
+            hidden_states = attn.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+        else:
+            hidden_states = attn.norm_out(hidden_states)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
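
The branch at the top of `__call__` above picks between the O(N) linear formulation and ordinary quadratic attention purely from the feature-map size. A minimal sketch of that dispatch rule, using the same comparison as the processor (the concrete sizes are illustrative):

```python
# Sketch of the dispatch rule from SanaMultiscaleAttnProcessor2_0.__call__:
# linear attention is chosen once the number of spatial tokens exceeds the
# per-head channel count, which is where the O(N) path pays off.
def uses_linear_attention(height: int, width: int, attention_head_dim: int) -> bool:
    return height * width > attention_head_dim


assert uses_linear_attention(32, 32, attention_head_dim=32)    # 1024 tokens -> linear path
assert not uses_linear_attention(4, 4, attention_head_dim=32)  # 16 tokens   -> quadratic path
```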
 class LoRAAttnProcessor:
+    r"""
+    Processor for implementing attention with LoRA.
+    """
+
     def __init__(self):
         pass
 
 
 class LoRAAttnProcessor2_0:
+    r"""
+    Processor for implementing attention with LoRA (enabled by default if you're using PyTorch 2.0).
+    """
+
     def __init__(self):
         pass
 
 
 class LoRAXFormersAttnProcessor:
+    r"""
+    Processor for implementing attention with LoRA using xFormers.
+    """
+
     def __init__(self):
         pass
 
 
 class LoRAAttnAddedKVProcessor:
+    r"""
+    Processor for implementing attention with LoRA with extra learnable key and value matrices for the text encoder.
+    """
+
     def __init__(self):
         pass
 
@@ -4283,6 +5866,165 @@ class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
         super().__init__()
 
 
+class SanaLinearAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product linear attention.
+    """
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        original_dtype = hidden_states.dtype
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
+        key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
+        value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
+
+        query = F.relu(query)
+        key = F.relu(key)
+
+        query, key, value = query.float(), key.float(), value.float()
+
+        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
+        scores = torch.matmul(value, key)
+        hidden_states = torch.matmul(scores, query)
+
+        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + 1e-15)
+        hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
+        hidden_states = hidden_states.to(original_dtype)
+
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if original_dtype == torch.float16:
+            hidden_states = hidden_states.clip(-65504, 65504)
+
+        return hidden_states
+
+
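
The `F.pad(..., value=1.0)` line above is what makes the final normalization work: the extra all-ones channel appended to `value` accumulates the denominator sum_j relu(k_j) . relu(q_n), which is then divided out. A small numeric sketch of the same computation with illustrative shapes (not part of the library):

```python
# Numeric sketch of the ReLU linear-attention trick used by SanaLinearAttnProcessor2_0.
import torch
import torch.nn.functional as F

b, h, d, n = 1, 2, 4, 8                                  # batch, heads, head_dim, tokens
query = F.relu(torch.randn(b, h, d, n))                  # (B, H, D, N)
key = F.relu(torch.randn(b, h, d, n)).transpose(2, 3)    # (B, H, N, D)
value = torch.randn(b, h, d, n)                          # (B, H, D, N)

value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)  # (B, H, D+1, N): last row is all ones
scores = torch.matmul(value, key)                        # (B, H, D+1, D)
out = torch.matmul(scores, query)                        # (B, H, D+1, N)

# Row D (from the ones-padding) now holds sum_j relu(k_j) . relu(q_n); divide it out.
out = out[:, :, :-1] / (out[:, :, -1:] + 1e-15)          # (B, H, D, N)
print(out.shape)                                         # torch.Size([1, 2, 4, 8])
```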
+class PAGCFGSanaLinearAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product linear attention.
+    """
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        original_dtype = hidden_states.dtype
+
+        hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
+        hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
+
+        query = attn.to_q(hidden_states_org)
+        key = attn.to_k(hidden_states_org)
+        value = attn.to_v(hidden_states_org)
+
+        query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
+        key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
+        value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
+
+        query = F.relu(query)
+        key = F.relu(key)
+
+        query, key, value = query.float(), key.float(), value.float()
+
+        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
+        scores = torch.matmul(value, key)
+        hidden_states_org = torch.matmul(scores, query)
+
+        hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + 1e-15)
+        hidden_states_org = hidden_states_org.flatten(1, 2).transpose(1, 2)
+        hidden_states_org = hidden_states_org.to(original_dtype)
+
+        hidden_states_org = attn.to_out[0](hidden_states_org)
+        hidden_states_org = attn.to_out[1](hidden_states_org)
+
+        # perturbed path (identity attention)
+        hidden_states_ptb = attn.to_v(hidden_states_ptb).to(original_dtype)
+
+        hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+        hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+
+        hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+
+        if original_dtype == torch.float16:
+            hidden_states = hidden_states.clip(-65504, 65504)
+
+        return hidden_states
+
+
+class PAGIdentitySanaLinearAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product linear attention.
+    """
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        original_dtype = hidden_states.dtype
+
+        hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
+
+        query = attn.to_q(hidden_states_org)
+        key = attn.to_k(hidden_states_org)
+        value = attn.to_v(hidden_states_org)
+
+        query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
+        key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
+        value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
+
+        query = F.relu(query)
+        key = F.relu(key)
+
+        query, key, value = query.float(), key.float(), value.float()
+
+        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
+        scores = torch.matmul(value, key)
+        hidden_states_org = torch.matmul(scores, query)
+
+        if hidden_states_org.dtype in [torch.float16, torch.bfloat16]:
+            hidden_states_org = hidden_states_org.float()
+
+        hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + 1e-15)
+        hidden_states_org = hidden_states_org.flatten(1, 2).transpose(1, 2)
+        hidden_states_org = hidden_states_org.to(original_dtype)
+
+        hidden_states_org = attn.to_out[0](hidden_states_org)
+        hidden_states_org = attn.to_out[1](hidden_states_org)
+
+        # perturbed path (identity attention)
+        hidden_states_ptb = attn.to_v(hidden_states_ptb).to(original_dtype)
+
+        hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+        hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+
+        hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+
+        if original_dtype == torch.float16:
+            hidden_states = hidden_states.clip(-65504, 65504)
+
+        return hidden_states
+
+
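
Both PAG variants above rely on how the pipeline stacks branches along the batch dimension before the processor runs: the CFG variant expects three chunks (unconditional, conditional, perturbed), the identity variant two (regular, perturbed). A hedged sketch of that assumed layout with toy tensors:

```python
# Illustrative batch layout only; real pipelines assemble these batches internally.
import torch

per_branch = torch.randn(2, 16, 64)  # (batch, tokens, channels) for one branch

# PAGCFGSanaLinearAttnProcessor2_0 calls hidden_states.chunk(3):
cfg_pag_batch = torch.cat([per_branch, per_branch, per_branch])  # uncond | cond | perturbed

# PAGIdentitySanaLinearAttnProcessor2_0 calls hidden_states.chunk(2):
pag_batch = torch.cat([per_branch, per_branch])                  # regular | perturbed

print(cfg_pag_batch.shape, pag_batch.shape)  # torch.Size([6, 16, 64]) torch.Size([4, 16, 64])
```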
 ADDED_KV_ATTENTION_PROCESSORS = (
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
@@ -4297,23 +6039,59 @@ CROSS_ATTENTION_PROCESSORS = (
     SlicedAttnProcessor,
     IPAdapterAttnProcessor,
     IPAdapterAttnProcessor2_0,
+    FluxIPAdapterJointAttnProcessor2_0,
 )
 
 AttentionProcessor = Union[
     AttnProcessor,
-
-    FusedAttnProcessor2_0,
-    XFormersAttnProcessor,
-    SlicedAttnProcessor,
+    CustomDiffusionAttnProcessor,
     AttnAddedKVProcessor,
-    SlicedAttnAddedKVProcessor,
     AttnAddedKVProcessor2_0,
+    JointAttnProcessor2_0,
+    PAGJointAttnProcessor2_0,
+    PAGCFGJointAttnProcessor2_0,
+    FusedJointAttnProcessor2_0,
+    AllegroAttnProcessor2_0,
+    AuraFlowAttnProcessor2_0,
+    FusedAuraFlowAttnProcessor2_0,
+    FluxAttnProcessor2_0,
+    FluxAttnProcessor2_0_NPU,
+    FusedFluxAttnProcessor2_0,
+    FusedFluxAttnProcessor2_0_NPU,
+    CogVideoXAttnProcessor2_0,
+    FusedCogVideoXAttnProcessor2_0,
     XFormersAttnAddedKVProcessor,
-
+    XFormersAttnProcessor,
+    XLAFlashAttnProcessor2_0,
+    AttnProcessorNPU,
+    AttnProcessor2_0,
+    MochiVaeAttnProcessor2_0,
+    MochiAttnProcessor2_0,
+    StableAudioAttnProcessor2_0,
+    HunyuanAttnProcessor2_0,
+    FusedHunyuanAttnProcessor2_0,
+    PAGHunyuanAttnProcessor2_0,
+    PAGCFGHunyuanAttnProcessor2_0,
+    LuminaAttnProcessor2_0,
+    FusedAttnProcessor2_0,
     CustomDiffusionXFormersAttnProcessor,
     CustomDiffusionAttnProcessor2_0,
-
+    SlicedAttnProcessor,
+    SlicedAttnAddedKVProcessor,
+    SanaLinearAttnProcessor2_0,
+    PAGCFGSanaLinearAttnProcessor2_0,
+    PAGIdentitySanaLinearAttnProcessor2_0,
+    SanaMultiscaleLinearAttention,
+    SanaMultiscaleAttnProcessor2_0,
+    SanaMultiscaleAttentionProjection,
+    IPAdapterAttnProcessor,
+    IPAdapterAttnProcessor2_0,
+    IPAdapterXFormersAttnProcessor,
+    SD3IPAdapterJointAttnProcessor2_0,
     PAGIdentitySelfAttnProcessor2_0,
-
-
+    PAGCFGIdentitySelfAttnProcessor2_0,
+    LoRAAttnProcessor,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    LoRAAttnAddedKVProcessor,
 ]
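
A hedged end-to-end sketch of plugging one of the newly exported processors into a standalone `Attention` block (assumes diffusers 0.32.0; the layer sizes are illustrative):

```python
# Illustrative only: wire the new Sana linear processor into a bare Attention block.
import torch
from diffusers.models.attention_processor import Attention, SanaLinearAttnProcessor2_0

attn = Attention(query_dim=64, heads=4, dim_head=16, processor=SanaLinearAttnProcessor2_0())
hidden_states = torch.randn(1, 32, 64)  # (batch, tokens, channels)

with torch.no_grad():
    out = attn(hidden_states)

print(out.shape)  # torch.Size([1, 32, 64])
```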