diffusers 0.29.2__py3-none-any.whl → 0.30.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_pipeline.py +2252 -0
  10. diffusers/loaders/peft.py +213 -5
  11. diffusers/loaders/single_file.py +3 -14
  12. diffusers/loaders/single_file_model.py +31 -10
  13. diffusers/loaders/single_file_utils.py +293 -8
  14. diffusers/loaders/textual_inversion.py +1 -6
  15. diffusers/loaders/unet.py +23 -208
  16. diffusers/models/__init__.py +20 -0
  17. diffusers/models/activations.py +22 -0
  18. diffusers/models/attention.py +386 -7
  19. diffusers/models/attention_processor.py +1937 -629
  20. diffusers/models/autoencoders/__init__.py +2 -0
  21. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  22. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1271 -0
  23. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  24. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  25. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  26. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  27. diffusers/models/autoencoders/vq_model.py +4 -4
  28. diffusers/models/controlnet.py +2 -3
  29. diffusers/models/controlnet_hunyuan.py +401 -0
  30. diffusers/models/controlnet_sd3.py +11 -11
  31. diffusers/models/controlnet_sparsectrl.py +789 -0
  32. diffusers/models/controlnet_xs.py +40 -10
  33. diffusers/models/downsampling.py +68 -0
  34. diffusers/models/embeddings.py +403 -36
  35. diffusers/models/model_loading_utils.py +1 -3
  36. diffusers/models/modeling_flax_utils.py +1 -6
  37. diffusers/models/modeling_utils.py +4 -16
  38. diffusers/models/normalization.py +203 -12
  39. diffusers/models/transformers/__init__.py +6 -0
  40. diffusers/models/transformers/auraflow_transformer_2d.py +543 -0
  41. diffusers/models/transformers/cogvideox_transformer_3d.py +485 -0
  42. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  43. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  44. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  45. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  46. diffusers/models/transformers/prior_transformer.py +1 -1
  47. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  48. diffusers/models/transformers/transformer_flux.py +455 -0
  49. diffusers/models/transformers/transformer_sd3.py +18 -4
  50. diffusers/models/unets/unet_1d_blocks.py +1 -1
  51. diffusers/models/unets/unet_2d_condition.py +8 -1
  52. diffusers/models/unets/unet_3d_blocks.py +51 -920
  53. diffusers/models/unets/unet_3d_condition.py +4 -1
  54. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  55. diffusers/models/unets/unet_kandinsky3.py +1 -1
  56. diffusers/models/unets/unet_motion_model.py +1330 -84
  57. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  58. diffusers/models/unets/unet_stable_cascade.py +1 -3
  59. diffusers/models/unets/uvit_2d.py +1 -1
  60. diffusers/models/upsampling.py +64 -0
  61. diffusers/models/vq_model.py +8 -4
  62. diffusers/optimization.py +1 -1
  63. diffusers/pipelines/__init__.py +100 -3
  64. diffusers/pipelines/animatediff/__init__.py +4 -0
  65. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  66. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  70. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  71. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  72. diffusers/pipelines/aura_flow/__init__.py +48 -0
  73. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  74. diffusers/pipelines/auto_pipeline.py +97 -19
  75. diffusers/pipelines/cogvideo/__init__.py +48 -0
  76. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +746 -0
  77. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  78. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  79. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  80. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  81. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  82. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  83. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  84. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  85. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  86. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  87. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  88. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  90. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  91. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  96. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  97. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  98. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  103. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  104. diffusers/pipelines/flux/__init__.py +47 -0
  105. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  106. diffusers/pipelines/flux/pipeline_output.py +21 -0
  107. diffusers/pipelines/free_init_utils.py +2 -0
  108. diffusers/pipelines/free_noise_utils.py +236 -0
  109. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  110. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  111. diffusers/pipelines/kolors/__init__.py +54 -0
  112. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  113. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  114. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  115. diffusers/pipelines/kolors/text_encoder.py +889 -0
  116. diffusers/pipelines/kolors/tokenizer.py +334 -0
  117. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  118. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  119. diffusers/pipelines/latte/__init__.py +48 -0
  120. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  121. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  122. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  123. diffusers/pipelines/lumina/__init__.py +48 -0
  124. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  125. diffusers/pipelines/pag/__init__.py +67 -0
  126. diffusers/pipelines/pag/pag_utils.py +237 -0
  127. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  128. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  129. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  130. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  131. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  132. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  133. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  138. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  139. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  140. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  141. diffusers/pipelines/pipeline_utils.py +2 -14
  142. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  143. diffusers/pipelines/stable_audio/__init__.py +50 -0
  144. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  145. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  146. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  147. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  156. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  157. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  158. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  160. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  161. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  162. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  163. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  164. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  165. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  166. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  167. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  168. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  175. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  179. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  180. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  181. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  182. diffusers/schedulers/__init__.py +8 -0
  183. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  184. diffusers/schedulers/scheduling_ddim.py +1 -1
  185. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  186. diffusers/schedulers/scheduling_ddpm.py +1 -1
  187. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  188. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  189. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  190. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  191. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  192. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  193. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  194. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  195. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  196. diffusers/schedulers/scheduling_ipndm.py +1 -1
  197. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  198. diffusers/schedulers/scheduling_utils.py +1 -3
  199. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  200. diffusers/training_utils.py +99 -14
  201. diffusers/utils/__init__.py +2 -2
  202. diffusers/utils/dummy_pt_objects.py +210 -0
  203. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  204. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  205. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  206. diffusers/utils/dynamic_modules_utils.py +1 -11
  207. diffusers/utils/export_utils.py +50 -6
  208. diffusers/utils/hub_utils.py +45 -42
  209. diffusers/utils/import_utils.py +37 -15
  210. diffusers/utils/loading_utils.py +80 -3
  211. diffusers/utils/testing_utils.py +11 -8
  212. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/METADATA +73 -83
  213. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/RECORD +217 -164
  214. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/WHEEL +1 -1
  215. diffusers/loaders/autoencoder.py +0 -146
  216. diffusers/loaders/controlnet.py +0 -136
  217. diffusers/loaders/lora.py +0 -1728
  218. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/LICENSE +0 -0
  219. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/entry_points.txt +0 -0
  220. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/top_level.txt +0 -0
@@ -13,8 +13,7 @@
  # limitations under the License.
  import inspect
  import math
- from importlib import import_module
- from typing import Callable, List, Optional, Union
+ from typing import Callable, List, Optional, Tuple, Union

  import torch
  import torch.nn.functional as F
@@ -23,8 +22,7 @@ from torch import nn
  from ..image_processor import IPAdapterMaskProcessor
  from ..utils import deprecate, logging
  from ..utils.import_utils import is_torch_npu_available, is_xformers_available
- from ..utils.torch_utils import maybe_allow_in_graph
- from .lora import LoRALinearLayer
+ from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph


  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -51,6 +49,10 @@ class Attention(nn.Module):
  The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
  heads (`int`, *optional*, defaults to 8):
  The number of heads to use for multi-head attention.
+ kv_heads (`int`, *optional*, defaults to `None`):
+ The number of key and value heads to use for multi-head attention. Defaults to `heads`. If
+ `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi
+ Query Attention (MQA) otherwise GQA is used.
  dim_head (`int`, *optional*, defaults to 64):
  The number of channels in each head.
  dropout (`float`, *optional*, defaults to 0.0):
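The new `kv_heads` argument documented above (and added to `Attention.__init__` in the next hunk) is what enables grouped-query attention in 0.30: the key/value projection width shrinks to `kv_heads * dim_head` while the query width stays at `heads * dim_head`. A minimal illustrative sketch, not taken from the diff, assuming diffusers 0.30 is installed:

from diffusers.models.attention_processor import Attention

# kv_heads == heads -> MHA, kv_heads == 1 -> MQA, anything in between -> GQA
mha = Attention(query_dim=512, heads=8, dim_head=64, kv_heads=8)
gqa = Attention(query_dim=512, heads=8, dim_head=64, kv_heads=2)
mqa = Attention(query_dim=512, heads=8, dim_head=64, kv_heads=1)

print(mha.to_q.out_features, mha.to_k.out_features)  # 512 512
print(gqa.to_q.out_features, gqa.to_k.out_features)  # 512 128
print(mqa.to_q.out_features, mqa.to_k.out_features)  # 512 64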
@@ -96,6 +98,7 @@ class Attention(nn.Module):
  query_dim: int,
  cross_attention_dim: Optional[int] = None,
  heads: int = 8,
+ kv_heads: Optional[int] = None,
  dim_head: int = 64,
  dropout: float = 0.0,
  bias: bool = False,
@@ -105,6 +108,7 @@ class Attention(nn.Module):
  cross_attention_norm_num_groups: int = 32,
  qk_norm: Optional[str] = None,
  added_kv_proj_dim: Optional[int] = None,
+ added_proj_bias: Optional[bool] = True,
  norm_num_groups: Optional[int] = None,
  spatial_norm_dim: Optional[int] = None,
  out_bias: bool = True,
@@ -117,9 +121,15 @@ class Attention(nn.Module):
  processor: Optional["AttnProcessor"] = None,
  out_dim: int = None,
  context_pre_only=None,
+ pre_only=False,
  ):
  super().__init__()
+
+ # To prevent circular import.
+ from .normalization import FP32LayerNorm, RMSNorm
+
  self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
  self.query_dim = query_dim
  self.use_bias = bias
  self.is_cross_attention = cross_attention_dim is not None
@@ -132,6 +142,7 @@ class Attention(nn.Module):
  self.fused_projections = False
  self.out_dim = out_dim if out_dim is not None else query_dim
  self.context_pre_only = context_pre_only
+ self.pre_only = pre_only

  # we make use of this private variable to know whether this class is loaded
  # with an deprecated state dict so that we can convert it on the fly
@@ -170,6 +181,16 @@ class Attention(nn.Module):
  elif qk_norm == "layer_norm":
  self.norm_q = nn.LayerNorm(dim_head, eps=eps)
  self.norm_k = nn.LayerNorm(dim_head, eps=eps)
+ elif qk_norm == "fp32_layer_norm":
+ self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ elif qk_norm == "layer_norm_across_heads":
+ # Lumina applys qk norm across all heads
+ self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
+ self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
+ elif qk_norm == "rms_norm":
+ self.norm_q = RMSNorm(dim_head, eps=eps)
+ self.norm_k = RMSNorm(dim_head, eps=eps)
  else:
  raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")

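The hunk above adds three new `qk_norm` modes on top of "layer_norm". A short hedged sketch (not part of the diff) of how they select the norm layers, assuming diffusers 0.30:

from diffusers.models.attention_processor import Attention

attn = Attention(query_dim=512, heads=8, dim_head=64, qk_norm="rms_norm")
print(type(attn.norm_q).__name__)  # RMSNorm, applied per head (dim_head features)

# "layer_norm_across_heads" normalizes across all heads at once (dim_head * heads features),
# the variant used by Lumina; kv_heads must be given because norm_k uses dim_head * kv_heads.
attn = Attention(query_dim=512, heads=8, kv_heads=8, dim_head=64, qk_norm="layer_norm_across_heads")
print(attn.norm_q.normalized_shape, attn.norm_k.normalized_shape)  # (512,) (512,)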
@@ -200,25 +221,38 @@

  if not self.only_cross_attention:
  # only relevant for the `AddedKVProcessor` classes
- self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
- self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
+ self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+ self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
  else:
  self.to_k = None
  self.to_v = None

+ self.added_proj_bias = added_proj_bias
  if self.added_kv_proj_dim is not None:
- self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
- self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
  if self.context_pre_only is not None:
- self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
+ self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)

- self.to_out = nn.ModuleList([])
- self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
- self.to_out.append(nn.Dropout(dropout))
+ if not self.pre_only:
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(nn.Dropout(dropout))

  if self.context_pre_only is not None and not self.context_pre_only:
  self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)

+ if qk_norm is not None and added_kv_proj_dim is not None:
+ if qk_norm == "fp32_layer_norm":
+ self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ elif qk_norm == "rms_norm":
+ self.norm_added_q = RMSNorm(dim_head, eps=eps)
+ self.norm_added_k = RMSNorm(dim_head, eps=eps)
+ else:
+ self.norm_added_q = None
+ self.norm_added_k = None
+
  # set attention processor
  # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
  # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
@@ -259,10 +293,6 @@ class Attention(nn.Module):
  The attention operation to use. Defaults to `None` which uses the default attention operation from
  `xformers`.
  """
- is_lora = hasattr(self, "processor") and isinstance(
- self.processor,
- LORA_ATTENTION_PROCESSORS,
- )
  is_custom_diffusion = hasattr(self, "processor") and isinstance(
  self.processor,
  (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
@@ -274,14 +304,13 @@ class Attention(nn.Module):
  AttnAddedKVProcessor2_0,
  SlicedAttnAddedKVProcessor,
  XFormersAttnAddedKVProcessor,
- LoRAAttnAddedKVProcessor,
  ),
  )

  if use_memory_efficient_attention_xformers:
- if is_added_kv_processor and (is_lora or is_custom_diffusion):
+ if is_added_kv_processor and is_custom_diffusion:
  raise NotImplementedError(
- f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
+ f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}"
  )
  if not is_xformers_available():
  raise ModuleNotFoundError(
@@ -307,18 +336,7 @@ class Attention(nn.Module):
  except Exception as e:
  raise e

- if is_lora:
- # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
- # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
- processor = LoRAXFormersAttnProcessor(
- hidden_size=self.processor.hidden_size,
- cross_attention_dim=self.processor.cross_attention_dim,
- rank=self.processor.rank,
- attention_op=attention_op,
- )
- processor.load_state_dict(self.processor.state_dict())
- processor.to(self.processor.to_q_lora.up.weight.device)
- elif is_custom_diffusion:
+ if is_custom_diffusion:
  processor = CustomDiffusionXFormersAttnProcessor(
  train_kv=self.processor.train_kv,
  train_q_out=self.processor.train_q_out,
@@ -341,18 +359,7 @@ class Attention(nn.Module):
  else:
  processor = XFormersAttnProcessor(attention_op=attention_op)
  else:
- if is_lora:
- attn_processor_class = (
- LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
- )
- processor = attn_processor_class(
- hidden_size=self.processor.hidden_size,
- cross_attention_dim=self.processor.cross_attention_dim,
- rank=self.processor.rank,
- )
- processor.load_state_dict(self.processor.state_dict())
- processor.to(self.processor.to_q_lora.up.weight.device)
- elif is_custom_diffusion:
+ if is_custom_diffusion:
  attn_processor_class = (
  CustomDiffusionAttnProcessor2_0
  if hasattr(F, "scaled_dot_product_attention")
@@ -442,82 +449,6 @@ class Attention(nn.Module):
  if not return_deprecated_lora:
  return self.processor

- # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
- # serialization format for LoRA Attention Processors. It should be deleted once the integration
- # with PEFT is completed.
- is_lora_activated = {
- name: module.lora_layer is not None
- for name, module in self.named_modules()
- if hasattr(module, "lora_layer")
- }
-
- # 1. if no layer has a LoRA activated we can return the processor as usual
- if not any(is_lora_activated.values()):
- return self.processor
-
- # If doesn't apply LoRA do `add_k_proj` or `add_v_proj`
- is_lora_activated.pop("add_k_proj", None)
- is_lora_activated.pop("add_v_proj", None)
- # 2. else it is not possible that only some layers have LoRA activated
- if not all(is_lora_activated.values()):
- raise ValueError(
- f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
- )
-
- # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
- non_lora_processor_cls_name = self.processor.__class__.__name__
- lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
-
- hidden_size = self.inner_dim
-
- # now create a LoRA attention processor from the LoRA layers
- if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
- kwargs = {
- "cross_attention_dim": self.cross_attention_dim,
- "rank": self.to_q.lora_layer.rank,
- "network_alpha": self.to_q.lora_layer.network_alpha,
- "q_rank": self.to_q.lora_layer.rank,
- "q_hidden_size": self.to_q.lora_layer.out_features,
- "k_rank": self.to_k.lora_layer.rank,
- "k_hidden_size": self.to_k.lora_layer.out_features,
- "v_rank": self.to_v.lora_layer.rank,
- "v_hidden_size": self.to_v.lora_layer.out_features,
- "out_rank": self.to_out[0].lora_layer.rank,
- "out_hidden_size": self.to_out[0].lora_layer.out_features,
- }
-
- if hasattr(self.processor, "attention_op"):
- kwargs["attention_op"] = self.processor.attention_op
-
- lora_processor = lora_processor_cls(hidden_size, **kwargs)
- lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
- lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
- lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
- lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
- elif lora_processor_cls == LoRAAttnAddedKVProcessor:
- lora_processor = lora_processor_cls(
- hidden_size,
- cross_attention_dim=self.add_k_proj.weight.shape[0],
- rank=self.to_q.lora_layer.rank,
- network_alpha=self.to_q.lora_layer.network_alpha,
- )
- lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
- lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
- lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
- lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
-
- # only save if used
- if self.add_k_proj.lora_layer is not None:
- lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
- lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
- else:
- lora_processor.add_k_proj_lora = None
- lora_processor.add_v_proj_lora = None
- else:
- raise ValueError(f"{lora_processor_cls} does not exist.")
-
- return lora_processor
-
  def forward(
  self,
  hidden_states: torch.Tensor,
@@ -609,7 +540,7 @@ class Attention(nn.Module):
  return tensor

  def get_attention_scores(
- self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
+ self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
  ) -> torch.Tensor:
  r"""
  Compute the attention scores.
@@ -760,6 +691,24 @@ class Attention(nn.Module):
  concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
  self.to_kv.bias.copy_(concatenated_bias)

+ # handle added projections for SD3 and others.
+ if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
+ concatenated_weights = torch.cat(
+ [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
+ )
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_added_qkv = nn.Linear(
+ in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
+ )
+ self.to_added_qkv.weight.copy_(concatenated_weights)
+ if self.added_proj_bias:
+ concatenated_bias = torch.cat(
+ [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
+ )
+ self.to_added_qkv.bias.copy_(concatenated_bias)
+
  self.fused_projections = fuse


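`fuse_projections` now also folds `add_q_proj`/`add_k_proj`/`add_v_proj` into a single `to_added_qkv` linear, mirroring what was already done for `to_q`/`to_k`/`to_v`. The sketch below (illustrative, not from the diff) shows why concatenating the weights along the output dimension is equivalent to running the three projections separately:

import torch
from torch import nn

q, k, v = nn.Linear(64, 64), nn.Linear(64, 64), nn.Linear(64, 64)
fused = nn.Linear(64, 192)  # out_features = 3 * 64, analogous to to_added_qkv
with torch.no_grad():
    fused.weight.copy_(torch.cat([q.weight, k.weight, v.weight]))
    fused.bias.copy_(torch.cat([q.bias, k.bias, v.bias]))

x = torch.randn(2, 10, 64)
expected = torch.cat([q(x), k(x), v(x)], dim=-1)
assert torch.allclose(fused(x), expected, atol=1e-5)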
@@ -1132,9 +1081,7 @@ class JointAttnProcessor2_0:
  key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
  value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

- hidden_states = hidden_states = F.scaled_dot_product_attention(
- query, key, value, dropout_p=0.0, is_causal=False
- )
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
  hidden_states = hidden_states.to(query.dtype)

@@ -1159,21 +1106,20 @@
  return hidden_states, encoder_hidden_states


- class FusedJointAttnProcessor2_0:
+ class PAGJointAttnProcessor2_0:
  """Attention processor used typically in processing the SD3-like self-attention projections."""

  def __init__(self):
  if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+ raise ImportError(
+ "PAGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )

  def __call__(
  self,
  attn: Attention,
  hidden_states: torch.FloatTensor,
  encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- *args,
- **kwargs,
  ) -> torch.FloatTensor:
  residual = hidden_states

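The PAG processors continued in the next hunk perturb self-attention over the image tokens: the image-to-image block of the attention mask is set to -inf except on its diagonal, so after softmax each image patch attends only to itself while still attending to the remaining (text) tokens. A standalone illustrative sketch of that mask construction (hypothetical sizes, not from the diff):

import torch

identity_block_size = 4  # number of image patch tokens (hypothetical)
seq_len = 6              # image patch tokens + text tokens (hypothetical)

full_mask = torch.zeros(seq_len, seq_len)
full_mask[:identity_block_size, :identity_block_size] = float("-inf")
full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
print(full_mask)  # rows 0-3: 0 on the diagonal and for columns 4-5, -inf elsewhere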
@@ -1186,257 +1132,1409 @@ class FusedJointAttnProcessor2_0:
1186
1132
  batch_size, channel, height, width = encoder_hidden_states.shape
1187
1133
  encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1188
1134
 
1189
- batch_size = encoder_hidden_states.shape[0]
1135
+ # store the length of image patch sequences to create a mask that prevents interaction between patches
1136
+ # similar to making the self-attention map an identity matrix
1137
+ identity_block_size = hidden_states.shape[1]
1138
+
1139
+ # chunk
1140
+ hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
1141
+ encoder_hidden_states_org, encoder_hidden_states_ptb = encoder_hidden_states.chunk(2)
1142
+
1143
+ ################## original path ##################
1144
+ batch_size = encoder_hidden_states_org.shape[0]
1190
1145
 
1191
1146
  # `sample` projections.
1192
- qkv = attn.to_qkv(hidden_states)
1193
- split_size = qkv.shape[-1] // 3
1194
- query, key, value = torch.split(qkv, split_size, dim=-1)
1147
+ query_org = attn.to_q(hidden_states_org)
1148
+ key_org = attn.to_k(hidden_states_org)
1149
+ value_org = attn.to_v(hidden_states_org)
1195
1150
 
1196
1151
  # `context` projections.
1197
- encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
1198
- split_size = encoder_qkv.shape[-1] // 3
1199
- (
1200
- encoder_hidden_states_query_proj,
1201
- encoder_hidden_states_key_proj,
1202
- encoder_hidden_states_value_proj,
1203
- ) = torch.split(encoder_qkv, split_size, dim=-1)
1152
+ encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
1153
+ encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
1154
+ encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
1204
1155
 
1205
1156
  # attention
1206
- query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
1207
- key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
1208
- value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
1157
+ query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
1158
+ key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
1159
+ value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
1209
1160
 
1210
- inner_dim = key.shape[-1]
1161
+ inner_dim = key_org.shape[-1]
1211
1162
  head_dim = inner_dim // attn.heads
1212
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1213
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1214
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1163
+ query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1164
+ key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1165
+ value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1215
1166
 
1216
- hidden_states = hidden_states = F.scaled_dot_product_attention(
1217
- query, key, value, dropout_p=0.0, is_causal=False
1167
+ hidden_states_org = F.scaled_dot_product_attention(
1168
+ query_org, key_org, value_org, dropout_p=0.0, is_causal=False
1218
1169
  )
1219
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1220
- hidden_states = hidden_states.to(query.dtype)
1170
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1171
+ hidden_states_org = hidden_states_org.to(query_org.dtype)
1221
1172
 
1222
1173
  # Split the attention outputs.
1223
- hidden_states, encoder_hidden_states = (
1224
- hidden_states[:, : residual.shape[1]],
1225
- hidden_states[:, residual.shape[1] :],
1174
+ hidden_states_org, encoder_hidden_states_org = (
1175
+ hidden_states_org[:, : residual.shape[1]],
1176
+ hidden_states_org[:, residual.shape[1] :],
1226
1177
  )
1227
1178
 
1228
1179
  # linear proj
1229
- hidden_states = attn.to_out[0](hidden_states)
1180
+ hidden_states_org = attn.to_out[0](hidden_states_org)
1230
1181
  # dropout
1231
- hidden_states = attn.to_out[1](hidden_states)
1182
+ hidden_states_org = attn.to_out[1](hidden_states_org)
1232
1183
  if not attn.context_pre_only:
1233
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
1184
+ encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
1234
1185
 
1235
1186
  if input_ndim == 4:
1236
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1187
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
1237
1188
  if context_input_ndim == 4:
1238
- encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1239
-
1240
- return hidden_states, encoder_hidden_states
1189
+ encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
1190
+ batch_size, channel, height, width
1191
+ )
1241
1192
 
1193
+ ################## perturbed path ##################
1242
1194
 
1243
- class XFormersAttnAddedKVProcessor:
1244
- r"""
1245
- Processor for implementing memory efficient attention using xFormers.
1195
+ batch_size = encoder_hidden_states_ptb.shape[0]
1246
1196
 
1247
- Args:
1248
- attention_op (`Callable`, *optional*, defaults to `None`):
1249
- The base
1250
- [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1251
- use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
1252
- operator.
1253
- """
1197
+ # `sample` projections.
1198
+ query_ptb = attn.to_q(hidden_states_ptb)
1199
+ key_ptb = attn.to_k(hidden_states_ptb)
1200
+ value_ptb = attn.to_v(hidden_states_ptb)
1254
1201
 
1255
- def __init__(self, attention_op: Optional[Callable] = None):
1256
- self.attention_op = attention_op
1202
+ # `context` projections.
1203
+ encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
1204
+ encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
1205
+ encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
1257
1206
 
1258
- def __call__(
1259
- self,
1260
- attn: Attention,
1261
- hidden_states: torch.Tensor,
1262
- encoder_hidden_states: Optional[torch.Tensor] = None,
1263
- attention_mask: Optional[torch.Tensor] = None,
1264
- ) -> torch.Tensor:
1265
- residual = hidden_states
1266
- hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
1267
- batch_size, sequence_length, _ = hidden_states.shape
1207
+ # attention
1208
+ query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
1209
+ key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
1210
+ value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
1268
1211
 
1269
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1212
+ inner_dim = key_ptb.shape[-1]
1213
+ head_dim = inner_dim // attn.heads
1214
+ query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1215
+ key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1216
+ value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1270
1217
 
1271
- if encoder_hidden_states is None:
1272
- encoder_hidden_states = hidden_states
1273
- elif attn.norm_cross:
1274
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1218
+ # create a full mask with all entries set to 0
1219
+ seq_len = query_ptb.size(2)
1220
+ full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
1275
1221
 
1276
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1222
+ # set the attention value between image patches to -inf
1223
+ full_mask[:identity_block_size, :identity_block_size] = float("-inf")
1277
1224
 
1278
- query = attn.to_q(hidden_states)
1279
- query = attn.head_to_batch_dim(query)
1225
+ # set the diagonal of the attention value between image patches to 0
1226
+ full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
1280
1227
 
1281
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1282
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1283
- encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
1284
- encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
1228
+ # expand the mask to match the attention weights shape
1229
+ full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions
1285
1230
 
1286
- if not attn.only_cross_attention:
1287
- key = attn.to_k(hidden_states)
1288
- value = attn.to_v(hidden_states)
1289
- key = attn.head_to_batch_dim(key)
1290
- value = attn.head_to_batch_dim(value)
1291
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
1292
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
1293
- else:
1294
- key = encoder_hidden_states_key_proj
1295
- value = encoder_hidden_states_value_proj
1231
+ hidden_states_ptb = F.scaled_dot_product_attention(
1232
+ query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
1233
+ )
1234
+ hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1235
+ hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
1296
1236
 
1297
- hidden_states = xformers.ops.memory_efficient_attention(
1298
- query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1237
+ # split the attention outputs.
1238
+ hidden_states_ptb, encoder_hidden_states_ptb = (
1239
+ hidden_states_ptb[:, : residual.shape[1]],
1240
+ hidden_states_ptb[:, residual.shape[1] :],
1299
1241
  )
1300
- hidden_states = hidden_states.to(query.dtype)
1301
- hidden_states = attn.batch_to_head_dim(hidden_states)
1302
1242
 
1303
1243
  # linear proj
1304
- hidden_states = attn.to_out[0](hidden_states)
1244
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
1305
1245
  # dropout
1306
- hidden_states = attn.to_out[1](hidden_states)
1246
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
1247
+ if not attn.context_pre_only:
1248
+ encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
1307
1249
 
1308
- hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
1309
- hidden_states = hidden_states + residual
1250
+ if input_ndim == 4:
1251
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
1252
+ if context_input_ndim == 4:
1253
+ encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
1254
+ batch_size, channel, height, width
1255
+ )
1310
1256
 
1311
- return hidden_states
1257
+ ################ concat ###############
1258
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
1259
+ encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
1312
1260
 
1261
+ return hidden_states, encoder_hidden_states
1313
1262
 
1314
- class XFormersAttnProcessor:
1315
- r"""
1316
- Processor for implementing memory efficient attention using xFormers.
1317
1263
 
1318
- Args:
1319
- attention_op (`Callable`, *optional*, defaults to `None`):
1320
- The base
1321
- [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1322
- use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
1323
- operator.
1324
- """
1264
+ class PAGCFGJointAttnProcessor2_0:
1265
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
1325
1266
 
1326
- def __init__(self, attention_op: Optional[Callable] = None):
1327
- self.attention_op = attention_op
1267
+ def __init__(self):
1268
+ if not hasattr(F, "scaled_dot_product_attention"):
1269
+ raise ImportError(
1270
+ "PAGCFGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
1271
+ )
1328
1272
 
1329
1273
  def __call__(
1330
1274
  self,
1331
1275
  attn: Attention,
1332
- hidden_states: torch.Tensor,
1333
- encoder_hidden_states: Optional[torch.Tensor] = None,
1334
- attention_mask: Optional[torch.Tensor] = None,
1335
- temb: Optional[torch.Tensor] = None,
1276
+ hidden_states: torch.FloatTensor,
1277
+ encoder_hidden_states: torch.FloatTensor = None,
1278
+ attention_mask: Optional[torch.FloatTensor] = None,
1336
1279
  *args,
1337
1280
  **kwargs,
1338
- ) -> torch.Tensor:
1339
- if len(args) > 0 or kwargs.get("scale", None) is not None:
1340
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1341
- deprecate("scale", "1.0.0", deprecation_message)
1342
-
1281
+ ) -> torch.FloatTensor:
1343
1282
  residual = hidden_states
1344
1283
 
1345
- if attn.spatial_norm is not None:
1346
- hidden_states = attn.spatial_norm(hidden_states, temb)
1347
-
1348
1284
  input_ndim = hidden_states.ndim
1349
-
1350
1285
  if input_ndim == 4:
1351
1286
  batch_size, channel, height, width = hidden_states.shape
1352
1287
  hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1288
+ context_input_ndim = encoder_hidden_states.ndim
1289
+ if context_input_ndim == 4:
1290
+ batch_size, channel, height, width = encoder_hidden_states.shape
1291
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1353
1292
 
1354
- batch_size, key_tokens, _ = (
1355
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1356
- )
1293
+ identity_block_size = hidden_states.shape[
1294
+ 1
1295
+ ] # patch embeddings width * height (correspond to self-attention map width or height)
1357
1296
 
1358
- attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
1359
- if attention_mask is not None:
1360
- # expand our mask's singleton query_tokens dimension:
1361
- # [batch*heads, 1, key_tokens] ->
1362
- # [batch*heads, query_tokens, key_tokens]
1363
- # so that it can be added as a bias onto the attention scores that xformers computes:
1364
- # [batch*heads, query_tokens, key_tokens]
1365
- # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
1366
- _, query_tokens, _ = hidden_states.shape
1367
- attention_mask = attention_mask.expand(-1, query_tokens, -1)
1297
+ # chunk
1298
+ hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
1299
+ hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
1368
1300
 
1369
- if attn.group_norm is not None:
1370
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1301
+ (
1302
+ encoder_hidden_states_uncond,
1303
+ encoder_hidden_states_org,
1304
+ encoder_hidden_states_ptb,
1305
+ ) = encoder_hidden_states.chunk(3)
1306
+ encoder_hidden_states_org = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_org])
1371
1307
 
1372
- query = attn.to_q(hidden_states)
1308
+ ################## original path ##################
1309
+ batch_size = encoder_hidden_states_org.shape[0]
1373
1310
 
1374
- if encoder_hidden_states is None:
1375
- encoder_hidden_states = hidden_states
1376
- elif attn.norm_cross:
1377
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1311
+ # `sample` projections.
1312
+ query_org = attn.to_q(hidden_states_org)
1313
+ key_org = attn.to_k(hidden_states_org)
1314
+ value_org = attn.to_v(hidden_states_org)
1378
1315
 
1379
- key = attn.to_k(encoder_hidden_states)
1380
- value = attn.to_v(encoder_hidden_states)
1316
+ # `context` projections.
1317
+ encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
1318
+ encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
1319
+ encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
1381
1320
 
1382
- query = attn.head_to_batch_dim(query).contiguous()
1383
- key = attn.head_to_batch_dim(key).contiguous()
1384
- value = attn.head_to_batch_dim(value).contiguous()
1321
+ # attention
1322
+ query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
1323
+ key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
1324
+ value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
1385
1325
 
1386
- hidden_states = xformers.ops.memory_efficient_attention(
1387
- query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1326
+ inner_dim = key_org.shape[-1]
1327
+ head_dim = inner_dim // attn.heads
1328
+ query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1329
+ key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1330
+ value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1331
+
1332
+ hidden_states_org = F.scaled_dot_product_attention(
1333
+ query_org, key_org, value_org, dropout_p=0.0, is_causal=False
1334
+ )
1335
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1336
+ hidden_states_org = hidden_states_org.to(query_org.dtype)
1337
+
1338
+ # Split the attention outputs.
1339
+ hidden_states_org, encoder_hidden_states_org = (
1340
+ hidden_states_org[:, : residual.shape[1]],
1341
+ hidden_states_org[:, residual.shape[1] :],
1388
1342
  )
1389
- hidden_states = hidden_states.to(query.dtype)
1390
- hidden_states = attn.batch_to_head_dim(hidden_states)
1391
1343
 
1392
1344
  # linear proj
1393
- hidden_states = attn.to_out[0](hidden_states)
1345
+ hidden_states_org = attn.to_out[0](hidden_states_org)
1394
1346
  # dropout
1395
- hidden_states = attn.to_out[1](hidden_states)
1347
+ hidden_states_org = attn.to_out[1](hidden_states_org)
1348
+ if not attn.context_pre_only:
1349
+ encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
1396
1350
 
1397
1351
  if input_ndim == 4:
1398
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1352
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
1353
+ if context_input_ndim == 4:
1354
+ encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
1355
+ batch_size, channel, height, width
1356
+ )
1399
1357
 
1400
- if attn.residual_connection:
1401
- hidden_states = hidden_states + residual
1358
+ ################## perturbed path ##################
1402
1359
 
1403
- hidden_states = hidden_states / attn.rescale_output_factor
1360
+ batch_size = encoder_hidden_states_ptb.shape[0]
1404
1361
 
1405
- return hidden_states
1362
+ # `sample` projections.
1363
+ query_ptb = attn.to_q(hidden_states_ptb)
1364
+ key_ptb = attn.to_k(hidden_states_ptb)
1365
+ value_ptb = attn.to_v(hidden_states_ptb)
1406
1366
 
1367
+ # `context` projections.
1368
+ encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
1369
+ encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
1370
+ encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
1407
1371
 
1408
- class AttnProcessorNPU:
1372
+ # attention
1373
+ query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
1374
+ key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
1375
+ value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
1409
1376
 
1410
- r"""
1411
- Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
1412
- fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
1413
- not significant.
1377
+ inner_dim = key_ptb.shape[-1]
1378
+ head_dim = inner_dim // attn.heads
1379
+ query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1380
+ key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1381
+ value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1414
1382
 
1415
- """
1383
+ # create a full mask with all entries set to 0
1384
+ seq_len = query_ptb.size(2)
1385
+ full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
1416
1386
 
1417
- def __init__(self):
1418
- if not is_torch_npu_available():
1419
- raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")
1387
+ # set the attention value between image patches to -inf
1388
+ full_mask[:identity_block_size, :identity_block_size] = float("-inf")
1420
1389
 
1421
- def __call__(
1422
- self,
1423
- attn: Attention,
1424
- hidden_states: torch.Tensor,
1425
- encoder_hidden_states: Optional[torch.Tensor] = None,
1426
- attention_mask: Optional[torch.Tensor] = None,
1427
- temb: Optional[torch.Tensor] = None,
1428
- *args,
1429
- **kwargs,
1430
- ) -> torch.Tensor:
1431
- if len(args) > 0 or kwargs.get("scale", None) is not None:
1432
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1433
- deprecate("scale", "1.0.0", deprecation_message)
1390
+ # set the diagonal of the attention value between image patches to 0
1391
+ full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
1434
1392
 
1435
- residual = hidden_states
1436
- if attn.spatial_norm is not None:
1437
- hidden_states = attn.spatial_norm(hidden_states, temb)
1393
+ # expand the mask to match the attention weights shape
1394
+ full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions
1438
1395
 
1439
- input_ndim = hidden_states.ndim
1396
+ hidden_states_ptb = F.scaled_dot_product_attention(
1397
+ query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
1398
+ )
1399
+ hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1400
+ hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
1401
+
1402
+ # split the attention outputs.
1403
+ hidden_states_ptb, encoder_hidden_states_ptb = (
1404
+ hidden_states_ptb[:, : residual.shape[1]],
1405
+ hidden_states_ptb[:, residual.shape[1] :],
1406
+ )
1407
+
1408
+ # linear proj
1409
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
1410
+ # dropout
1411
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
1412
+ if not attn.context_pre_only:
1413
+ encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
1414
+
1415
+ if input_ndim == 4:
1416
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
1417
+ if context_input_ndim == 4:
1418
+ encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
1419
+ batch_size, channel, height, width
1420
+ )
1421
+
1422
+ ################ concat ###############
1423
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
1424
+ encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
1425
+
1426
+ return hidden_states, encoder_hidden_states
1427
+
1428
+
1429
+ class FusedJointAttnProcessor2_0:
1430
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
1431
+
1432
+ def __init__(self):
1433
+ if not hasattr(F, "scaled_dot_product_attention"):
1434
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1435
+
1436
+ def __call__(
1437
+ self,
1438
+ attn: Attention,
1439
+ hidden_states: torch.FloatTensor,
1440
+ encoder_hidden_states: torch.FloatTensor = None,
1441
+ attention_mask: Optional[torch.FloatTensor] = None,
1442
+ *args,
1443
+ **kwargs,
1444
+ ) -> torch.FloatTensor:
1445
+ residual = hidden_states
1446
+
1447
+ input_ndim = hidden_states.ndim
1448
+ if input_ndim == 4:
1449
+ batch_size, channel, height, width = hidden_states.shape
1450
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1451
+ context_input_ndim = encoder_hidden_states.ndim
1452
+ if context_input_ndim == 4:
1453
+ batch_size, channel, height, width = encoder_hidden_states.shape
1454
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1455
+
1456
+ batch_size = encoder_hidden_states.shape[0]
1457
+
1458
+ # `sample` projections.
1459
+ qkv = attn.to_qkv(hidden_states)
1460
+ split_size = qkv.shape[-1] // 3
1461
+ query, key, value = torch.split(qkv, split_size, dim=-1)
1462
+
1463
+ # `context` projections.
1464
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
1465
+ split_size = encoder_qkv.shape[-1] // 3
1466
+ (
1467
+ encoder_hidden_states_query_proj,
1468
+ encoder_hidden_states_key_proj,
1469
+ encoder_hidden_states_value_proj,
1470
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
1471
+
1472
+ # attention
1473
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
1474
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
1475
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
1476
+
1477
+ inner_dim = key.shape[-1]
1478
+ head_dim = inner_dim // attn.heads
1479
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1480
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1481
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1482
+
1483
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
1484
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1485
+ hidden_states = hidden_states.to(query.dtype)
1486
+
1487
+ # Split the attention outputs.
1488
+ hidden_states, encoder_hidden_states = (
1489
+ hidden_states[:, : residual.shape[1]],
1490
+ hidden_states[:, residual.shape[1] :],
1491
+ )
1492
+
1493
+ # linear proj
1494
+ hidden_states = attn.to_out[0](hidden_states)
1495
+ # dropout
1496
+ hidden_states = attn.to_out[1](hidden_states)
1497
+ if not attn.context_pre_only:
1498
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
1499
+
1500
+ if input_ndim == 4:
1501
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1502
+ if context_input_ndim == 4:
1503
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1504
+
1505
+ return hidden_states, encoder_hidden_states
1506
+
1507
+
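The fused `to_qkv` projection used by `FusedJointAttnProcessor2_0` above is simply the three separate q/k/v linears stacked into one weight matrix and split back after the matmul. A minimal sketch of that equivalence, with made-up sizes and plain `nn.Linear` layers rather than the diffusers `Attention` module:

import torch
import torch.nn as nn

dim, batch, seq = 64, 2, 8
to_q, to_k, to_v = nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)

# build a fused projection whose weight/bias are the three projections stacked
to_qkv = nn.Linear(dim, 3 * dim)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))
    to_qkv.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias], dim=0))

hidden_states = torch.randn(batch, seq, dim)
qkv = to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3  # same split as the processor above
query, key, value = torch.split(qkv, split_size, dim=-1)

assert torch.allclose(query, to_q(hidden_states), atol=1e-6)
assert torch.allclose(key, to_k(hidden_states), atol=1e-6)
assert torch.allclose(value, to_v(hidden_states), atol=1e-6)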
1508
+ class AuraFlowAttnProcessor2_0:
1509
+ """Attention processor used typically in processing Aura Flow."""
1510
+
1511
+ def __init__(self):
1512
+ if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
1513
+ raise ImportError(
1514
+ "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
1515
+ )
1516
+
1517
+ def __call__(
1518
+ self,
1519
+ attn: Attention,
1520
+ hidden_states: torch.FloatTensor,
1521
+ encoder_hidden_states: torch.FloatTensor = None,
1522
+ *args,
1523
+ **kwargs,
1524
+ ) -> torch.FloatTensor:
1525
+ batch_size = hidden_states.shape[0]
1526
+
1527
+ # `sample` projections.
1528
+ query = attn.to_q(hidden_states)
1529
+ key = attn.to_k(hidden_states)
1530
+ value = attn.to_v(hidden_states)
1531
+
1532
+ # `context` projections.
1533
+ if encoder_hidden_states is not None:
1534
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
1535
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1536
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1537
+
1538
+ # Reshape.
1539
+ inner_dim = key.shape[-1]
1540
+ head_dim = inner_dim // attn.heads
1541
+ query = query.view(batch_size, -1, attn.heads, head_dim)
1542
+ key = key.view(batch_size, -1, attn.heads, head_dim)
1543
+ value = value.view(batch_size, -1, attn.heads, head_dim)
1544
+
1545
+ # Apply QK norm.
1546
+ if attn.norm_q is not None:
1547
+ query = attn.norm_q(query)
1548
+ if attn.norm_k is not None:
1549
+ key = attn.norm_k(key)
1550
+
1551
+ # Concatenate the projections.
1552
+ if encoder_hidden_states is not None:
1553
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
1554
+ batch_size, -1, attn.heads, head_dim
1555
+ )
1556
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
1557
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
1558
+ batch_size, -1, attn.heads, head_dim
1559
+ )
1560
+
1561
+ if attn.norm_added_q is not None:
1562
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
1563
+ if attn.norm_added_k is not None:
1564
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
1565
+
1566
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
1567
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
1568
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
1569
+
1570
+ query = query.transpose(1, 2)
1571
+ key = key.transpose(1, 2)
1572
+ value = value.transpose(1, 2)
1573
+
1574
+ # Attention.
1575
+ hidden_states = F.scaled_dot_product_attention(
1576
+ query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
1577
+ )
1578
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1579
+ hidden_states = hidden_states.to(query.dtype)
1580
+
1581
+ # Split the attention outputs.
1582
+ if encoder_hidden_states is not None:
1583
+ hidden_states, encoder_hidden_states = (
1584
+ hidden_states[:, encoder_hidden_states.shape[1] :],
1585
+ hidden_states[:, : encoder_hidden_states.shape[1]],
1586
+ )
1587
+
1588
+ # linear proj
1589
+ hidden_states = attn.to_out[0](hidden_states)
1590
+ # dropout
1591
+ hidden_states = attn.to_out[1](hidden_states)
1592
+ if encoder_hidden_states is not None:
1593
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
1594
+
1595
+ if encoder_hidden_states is not None:
1596
+ return hidden_states, encoder_hidden_states
1597
+ else:
1598
+ return hidden_states
1599
+
1600
+
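A rough sketch (illustrative shapes only, not the diffusers API) of the token ordering `AuraFlowAttnProcessor2_0` uses: the context projections are concatenated in front of the sample tokens before attention, so the sample part of the output is recovered by slicing from the context length onward.

import torch
import torch.nn.functional as F

batch, heads, head_dim = 2, 4, 16
sample_len, context_len = 12, 5

sample = torch.randn(batch, heads, sample_len, head_dim)
context = torch.randn(batch, heads, context_len, head_dim)

# joint sequence: [context tokens | sample tokens]
q = k = v = torch.cat([context, sample], dim=2)
out = F.scaled_dot_product_attention(q, k, v)
out = out.transpose(1, 2).reshape(batch, -1, heads * head_dim)

# undo the concatenation: the sample part starts at the context length
sample_out, context_out = out[:, context_len:], out[:, :context_len]
assert sample_out.shape[1] == sample_len
assert context_out.shape[1] == context_len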
1601
+ class FusedAuraFlowAttnProcessor2_0:
1602
+ """Attention processor used typically in processing Aura Flow with fused projections."""
1603
+
1604
+ def __init__(self):
1605
+ if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
1606
+ raise ImportError(
1607
+ "FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
1608
+ )
1609
+
1610
+ def __call__(
1611
+ self,
1612
+ attn: Attention,
1613
+ hidden_states: torch.FloatTensor,
1614
+ encoder_hidden_states: torch.FloatTensor = None,
1615
+ *args,
1616
+ **kwargs,
1617
+ ) -> torch.FloatTensor:
1618
+ batch_size = hidden_states.shape[0]
1619
+
1620
+ # `sample` projections.
1621
+ qkv = attn.to_qkv(hidden_states)
1622
+ split_size = qkv.shape[-1] // 3
1623
+ query, key, value = torch.split(qkv, split_size, dim=-1)
1624
+
1625
+ # `context` projections.
1626
+ if encoder_hidden_states is not None:
1627
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
1628
+ split_size = encoder_qkv.shape[-1] // 3
1629
+ (
1630
+ encoder_hidden_states_query_proj,
1631
+ encoder_hidden_states_key_proj,
1632
+ encoder_hidden_states_value_proj,
1633
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
1634
+
1635
+ # Reshape.
1636
+ inner_dim = key.shape[-1]
1637
+ head_dim = inner_dim // attn.heads
1638
+ query = query.view(batch_size, -1, attn.heads, head_dim)
1639
+ key = key.view(batch_size, -1, attn.heads, head_dim)
1640
+ value = value.view(batch_size, -1, attn.heads, head_dim)
1641
+
1642
+ # Apply QK norm.
1643
+ if attn.norm_q is not None:
1644
+ query = attn.norm_q(query)
1645
+ if attn.norm_k is not None:
1646
+ key = attn.norm_k(key)
1647
+
1648
+ # Concatenate the projections.
1649
+ if encoder_hidden_states is not None:
1650
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
1651
+ batch_size, -1, attn.heads, head_dim
1652
+ )
1653
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
1654
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
1655
+ batch_size, -1, attn.heads, head_dim
1656
+ )
1657
+
1658
+ if attn.norm_added_q is not None:
1659
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
1660
+ if attn.norm_added_k is not None:
1661
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
1662
+
1663
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
1664
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
1665
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
1666
+
1667
+ query = query.transpose(1, 2)
1668
+ key = key.transpose(1, 2)
1669
+ value = value.transpose(1, 2)
1670
+
1671
+ # Attention.
1672
+ hidden_states = F.scaled_dot_product_attention(
1673
+ query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
1674
+ )
1675
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1676
+ hidden_states = hidden_states.to(query.dtype)
1677
+
1678
+ # Split the attention outputs.
1679
+ if encoder_hidden_states is not None:
1680
+ hidden_states, encoder_hidden_states = (
1681
+ hidden_states[:, encoder_hidden_states.shape[1] :],
1682
+ hidden_states[:, : encoder_hidden_states.shape[1]],
1683
+ )
1684
+
1685
+ # linear proj
1686
+ hidden_states = attn.to_out[0](hidden_states)
1687
+ # dropout
1688
+ hidden_states = attn.to_out[1](hidden_states)
1689
+ if encoder_hidden_states is not None:
1690
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
1691
+
1692
+ if encoder_hidden_states is not None:
1693
+ return hidden_states, encoder_hidden_states
1694
+ else:
1695
+ return hidden_states
1696
+
1697
+
1698
+ # YiYi to-do: refactor rope related functions/classes
1699
+ def apply_rope(xq, xk, freqs_cis):
1700
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
1701
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
1702
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
1703
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
1704
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
1705
+
1706
+
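As a sanity check on the shapes `apply_rope` expects (my reading of the code above, so treat the exact `freqs_cis` layout as an assumption): it consumes one 2x2 rotation matrix per position and channel pair, rotating each consecutive (even, odd) channel pair of q/k. Assuming `apply_rope` as defined above is in scope, a `freqs_cis` built from rotation matrices preserves the norm of every pair:

import torch

batch, heads, seq, head_dim = 1, 2, 6, 8
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)

# one rotation angle per (position, channel pair); freqs_cis[s, d] = [[cos, -sin], [sin, cos]]
theta = torch.randn(seq, head_dim // 2)
cos, sin = torch.cos(theta), torch.sin(theta)
freqs_cis = torch.stack(
    [torch.stack([cos, -sin], dim=-1), torch.stack([sin, cos], dim=-1)], dim=-2
)  # (seq, head_dim // 2, 2, 2)

q_rot, k_rot = apply_rope(q, k, freqs_cis)
assert q_rot.shape == q.shape and k_rot.shape == k.shape

# rotations leave the norm of every 2-channel pair unchanged
assert torch.allclose(
    q_rot.reshape(*q.shape[:-1], -1, 2).norm(dim=-1),
    q.reshape(*q.shape[:-1], -1, 2).norm(dim=-1),
    atol=1e-5,
)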
1707
+ class FluxSingleAttnProcessor2_0:
1708
+ r"""
1709
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
1710
+ """
1711
+
1712
+ def __init__(self):
1713
+ if not hasattr(F, "scaled_dot_product_attention"):
1714
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1715
+
1716
+ def __call__(
1717
+ self,
1718
+ attn: Attention,
1719
+ hidden_states: torch.Tensor,
1720
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1721
+ attention_mask: Optional[torch.FloatTensor] = None,
1722
+ image_rotary_emb: Optional[torch.Tensor] = None,
1723
+ ) -> torch.Tensor:
1724
+ input_ndim = hidden_states.ndim
1725
+
1726
+ if input_ndim == 4:
1727
+ batch_size, channel, height, width = hidden_states.shape
1728
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1729
+
1730
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1731
+
1732
+ query = attn.to_q(hidden_states)
1733
+ if encoder_hidden_states is None:
1734
+ encoder_hidden_states = hidden_states
1735
+
1736
+ key = attn.to_k(encoder_hidden_states)
1737
+ value = attn.to_v(encoder_hidden_states)
1738
+
1739
+ inner_dim = key.shape[-1]
1740
+ head_dim = inner_dim // attn.heads
1741
+
1742
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1743
+
1744
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1745
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1746
+
1747
+ if attn.norm_q is not None:
1748
+ query = attn.norm_q(query)
1749
+ if attn.norm_k is not None:
1750
+ key = attn.norm_k(key)
1751
+
1752
+ # Apply RoPE if needed
1753
+ if image_rotary_emb is not None:
1754
+ # YiYi to-do: update using apply_rotary_emb
1755
+ # from ..embeddings import apply_rotary_emb
1756
+ # query = apply_rotary_emb(query, image_rotary_emb)
1757
+ # key = apply_rotary_emb(key, image_rotary_emb)
1758
+ query, key = apply_rope(query, key, image_rotary_emb)
1759
+
1760
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1761
+ # TODO: add support for attn.scale when we move to Torch 2.1
1762
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
1763
+
1764
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1765
+ hidden_states = hidden_states.to(query.dtype)
1766
+
1767
+ if input_ndim == 4:
1768
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1769
+
1770
+ return hidden_states
1771
+
1772
+
1773
+ class FluxAttnProcessor2_0:
1774
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
1775
+
1776
+ def __init__(self):
1777
+ if not hasattr(F, "scaled_dot_product_attention"):
1778
+ raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1779
+
1780
+ def __call__(
1781
+ self,
1782
+ attn: Attention,
1783
+ hidden_states: torch.FloatTensor,
1784
+ encoder_hidden_states: torch.FloatTensor = None,
1785
+ attention_mask: Optional[torch.FloatTensor] = None,
1786
+ image_rotary_emb: Optional[torch.Tensor] = None,
1787
+ ) -> torch.FloatTensor:
1788
+ input_ndim = hidden_states.ndim
1789
+ if input_ndim == 4:
1790
+ batch_size, channel, height, width = hidden_states.shape
1791
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1792
+ context_input_ndim = encoder_hidden_states.ndim
1793
+ if context_input_ndim == 4:
1794
+ batch_size, channel, height, width = encoder_hidden_states.shape
1795
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1796
+
1797
+ batch_size = encoder_hidden_states.shape[0]
1798
+
1799
+ # `sample` projections.
1800
+ query = attn.to_q(hidden_states)
1801
+ key = attn.to_k(hidden_states)
1802
+ value = attn.to_v(hidden_states)
1803
+
1804
+ inner_dim = key.shape[-1]
1805
+ head_dim = inner_dim // attn.heads
1806
+
1807
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1808
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1809
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1810
+
1811
+ if attn.norm_q is not None:
1812
+ query = attn.norm_q(query)
1813
+ if attn.norm_k is not None:
1814
+ key = attn.norm_k(key)
1815
+
1816
+ # `context` projections.
1817
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
1818
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1819
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1820
+
1821
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
1822
+ batch_size, -1, attn.heads, head_dim
1823
+ ).transpose(1, 2)
1824
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
1825
+ batch_size, -1, attn.heads, head_dim
1826
+ ).transpose(1, 2)
1827
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
1828
+ batch_size, -1, attn.heads, head_dim
1829
+ ).transpose(1, 2)
1830
+
1831
+ if attn.norm_added_q is not None:
1832
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
1833
+ if attn.norm_added_k is not None:
1834
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
1835
+
1836
+ # attention
1837
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
1838
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
1839
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
1840
+
1841
+ if image_rotary_emb is not None:
1842
+ # YiYi to-do: update using apply_rotary_emb
1843
+ # from ..embeddings import apply_rotary_emb
1844
+ # query = apply_rotary_emb(query, image_rotary_emb)
1845
+ # key = apply_rotary_emb(key, image_rotary_emb)
1846
+ query, key = apply_rope(query, key, image_rotary_emb)
1847
+
1848
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
1849
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1850
+ hidden_states = hidden_states.to(query.dtype)
1851
+
1852
+ encoder_hidden_states, hidden_states = (
1853
+ hidden_states[:, : encoder_hidden_states.shape[1]],
1854
+ hidden_states[:, encoder_hidden_states.shape[1] :],
1855
+ )
1856
+
1857
+ # linear proj
1858
+ hidden_states = attn.to_out[0](hidden_states)
1859
+ # dropout
1860
+ hidden_states = attn.to_out[1](hidden_states)
1861
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
1862
+
1863
+ if input_ndim == 4:
1864
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1865
+ if context_input_ndim == 4:
1866
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1867
+
1868
+ return hidden_states, encoder_hidden_states
1869
+
1870
+
1871
+ class CogVideoXAttnProcessor2_0:
1872
+ r"""
1873
+ Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
1874
+ query and key vectors, but does not include spatial normalization.
1875
+ """
1876
+
1877
+ def __init__(self):
1878
+ if not hasattr(F, "scaled_dot_product_attention"):
1879
+ raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1880
+
1881
+ def __call__(
1882
+ self,
1883
+ attn: Attention,
1884
+ hidden_states: torch.Tensor,
1885
+ encoder_hidden_states: torch.Tensor,
1886
+ attention_mask: Optional[torch.Tensor] = None,
1887
+ image_rotary_emb: Optional[torch.Tensor] = None,
1888
+ ) -> torch.Tensor:
1889
+ text_seq_length = encoder_hidden_states.size(1)
1890
+
1891
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
1892
+
1893
+ batch_size, sequence_length, _ = (
1894
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1895
+ )
1896
+
1897
+ if attention_mask is not None:
1898
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1899
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1900
+
1901
+ query = attn.to_q(hidden_states)
1902
+ key = attn.to_k(hidden_states)
1903
+ value = attn.to_v(hidden_states)
1904
+
1905
+ inner_dim = key.shape[-1]
1906
+ head_dim = inner_dim // attn.heads
1907
+
1908
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1909
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1910
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1911
+
1912
+ if attn.norm_q is not None:
1913
+ query = attn.norm_q(query)
1914
+ if attn.norm_k is not None:
1915
+ key = attn.norm_k(key)
1916
+
1917
+ # Apply RoPE if needed
1918
+ if image_rotary_emb is not None:
1919
+ from .embeddings import apply_rotary_emb
1920
+
1921
+ query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
1922
+ if not attn.is_cross_attention:
1923
+ key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
1924
+
1925
+ hidden_states = F.scaled_dot_product_attention(
1926
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1927
+ )
1928
+
1929
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1930
+
1931
+ # linear proj
1932
+ hidden_states = attn.to_out[0](hidden_states)
1933
+ # dropout
1934
+ hidden_states = attn.to_out[1](hidden_states)
1935
+
1936
+ encoder_hidden_states, hidden_states = hidden_states.split(
1937
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
1938
+ )
1939
+ return hidden_states, encoder_hidden_states
1940
+
1941
+
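The slicing in `CogVideoXAttnProcessor2_0` applies the rotary embedding only to the tokens after `text_seq_length`, leaving the prepended text tokens untouched. A small sketch of that layout, using a stand-in `rotate_pairs` helper instead of diffusers' `apply_rotary_emb` (names and shapes are illustrative):

import torch

def rotate_pairs(x, theta):
    # stand-in RoPE: rotate consecutive channel pairs of x by per-position angles
    x_pairs = x.reshape(*x.shape[:-1], -1, 2)
    cos, sin = torch.cos(theta), torch.sin(theta)
    x0, x1 = x_pairs[..., 0], x_pairs[..., 1]
    return torch.stack([x0 * cos - x1 * sin, x0 * sin + x1 * cos], dim=-1).reshape(x.shape)

batch, heads, head_dim = 1, 2, 8
text_seq_length, image_seq_length = 4, 10
query = torch.randn(batch, heads, text_seq_length + image_seq_length, head_dim)

theta = torch.randn(image_seq_length, head_dim // 2)  # angles for image positions only
query_rope = query.clone()
query_rope[:, :, text_seq_length:] = rotate_pairs(query[:, :, text_seq_length:], theta)

# the text tokens at the front of the joint sequence are not rotated
assert torch.equal(query_rope[:, :, :text_seq_length], query[:, :, :text_seq_length])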
1942
+ class FusedCogVideoXAttnProcessor2_0:
1943
+ r"""
1944
+ Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
1945
+ query and key vectors, but does not include spatial normalization.
1946
+ """
1947
+
1948
+ def __init__(self):
1949
+ if not hasattr(F, "scaled_dot_product_attention"):
1950
+ raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1951
+
1952
+ def __call__(
1953
+ self,
1954
+ attn: Attention,
1955
+ hidden_states: torch.Tensor,
1956
+ encoder_hidden_states: torch.Tensor,
1957
+ attention_mask: Optional[torch.Tensor] = None,
1958
+ image_rotary_emb: Optional[torch.Tensor] = None,
1959
+ ) -> torch.Tensor:
1960
+ text_seq_length = encoder_hidden_states.size(1)
1961
+
1962
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
1963
+
1964
+ batch_size, sequence_length, _ = (
1965
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1966
+ )
1967
+
1968
+ if attention_mask is not None:
1969
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1970
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1971
+
1972
+ qkv = attn.to_qkv(hidden_states)
1973
+ split_size = qkv.shape[-1] // 3
1974
+ query, key, value = torch.split(qkv, split_size, dim=-1)
1975
+
1976
+ inner_dim = key.shape[-1]
1977
+ head_dim = inner_dim // attn.heads
1978
+
1979
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1980
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1981
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1982
+
1983
+ if attn.norm_q is not None:
1984
+ query = attn.norm_q(query)
1985
+ if attn.norm_k is not None:
1986
+ key = attn.norm_k(key)
1987
+
1988
+ # Apply RoPE if needed
1989
+ if image_rotary_emb is not None:
1990
+ from .embeddings import apply_rotary_emb
1991
+
1992
+ query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
1993
+ if not attn.is_cross_attention:
1994
+ key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
1995
+
1996
+ hidden_states = F.scaled_dot_product_attention(
1997
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1998
+ )
1999
+
2000
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2001
+
2002
+ # linear proj
2003
+ hidden_states = attn.to_out[0](hidden_states)
2004
+ # dropout
2005
+ hidden_states = attn.to_out[1](hidden_states)
2006
+
2007
+ encoder_hidden_states, hidden_states = hidden_states.split(
2008
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
2009
+ )
2010
+ return hidden_states, encoder_hidden_states
2011
+
2012
+
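Both CogVideoX processors finish by splitting the joint output back into its text and video parts along the sequence dimension. A minimal illustration of that unpacking, with made-up sizes:

import torch

text_seq_length, video_seq_length, dim = 4, 12, 8
hidden_states = torch.randn(2, text_seq_length + video_seq_length, dim)

encoder_hidden_states, hidden_states = hidden_states.split(
    [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
assert encoder_hidden_states.shape == (2, text_seq_length, dim)
assert hidden_states.shape == (2, video_seq_length, dim)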
2013
+ class XFormersAttnAddedKVProcessor:
2014
+ r"""
2015
+ Processor for implementing memory efficient attention using xFormers.
2016
+
2017
+ Args:
2018
+ attention_op (`Callable`, *optional*, defaults to `None`):
2019
+ The base
2020
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
2021
+ use as the attention operator. It is recommended to set it to `None` and allow xFormers to choose the best
2022
+ operator.
2023
+ """
2024
+
2025
+ def __init__(self, attention_op: Optional[Callable] = None):
2026
+ self.attention_op = attention_op
2027
+
2028
+ def __call__(
2029
+ self,
2030
+ attn: Attention,
2031
+ hidden_states: torch.Tensor,
2032
+ encoder_hidden_states: Optional[torch.Tensor] = None,
2033
+ attention_mask: Optional[torch.Tensor] = None,
2034
+ ) -> torch.Tensor:
2035
+ residual = hidden_states
2036
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
2037
+ batch_size, sequence_length, _ = hidden_states.shape
2038
+
2039
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
2040
+
2041
+ if encoder_hidden_states is None:
2042
+ encoder_hidden_states = hidden_states
2043
+ elif attn.norm_cross:
2044
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
2045
+
2046
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
2047
+
2048
+ query = attn.to_q(hidden_states)
2049
+ query = attn.head_to_batch_dim(query)
2050
+
2051
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
2052
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
2053
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
2054
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
2055
+
2056
+ if not attn.only_cross_attention:
2057
+ key = attn.to_k(hidden_states)
2058
+ value = attn.to_v(hidden_states)
2059
+ key = attn.head_to_batch_dim(key)
2060
+ value = attn.head_to_batch_dim(value)
2061
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
2062
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
2063
+ else:
2064
+ key = encoder_hidden_states_key_proj
2065
+ value = encoder_hidden_states_value_proj
2066
+
2067
+ hidden_states = xformers.ops.memory_efficient_attention(
2068
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
2069
+ )
2070
+ hidden_states = hidden_states.to(query.dtype)
2071
+ hidden_states = attn.batch_to_head_dim(hidden_states)
2072
+
2073
+ # linear proj
2074
+ hidden_states = attn.to_out[0](hidden_states)
2075
+ # dropout
2076
+ hidden_states = attn.to_out[1](hidden_states)
2077
+
2078
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
2079
+ hidden_states = hidden_states + residual
2080
+
2081
+ return hidden_states
2082
+
2083
+
2084
+ class XFormersAttnProcessor:
2085
+ r"""
2086
+ Processor for implementing memory efficient attention using xFormers.
2087
+
2088
+ Args:
2089
+ attention_op (`Callable`, *optional*, defaults to `None`):
2090
+ The base
2091
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
2092
+ use as the attention operator. It is recommended to set it to `None` and allow xFormers to choose the best
2093
+ operator.
2094
+ """
2095
+
2096
+ def __init__(self, attention_op: Optional[Callable] = None):
2097
+ self.attention_op = attention_op
2098
+
2099
+ def __call__(
2100
+ self,
2101
+ attn: Attention,
2102
+ hidden_states: torch.Tensor,
2103
+ encoder_hidden_states: Optional[torch.Tensor] = None,
2104
+ attention_mask: Optional[torch.Tensor] = None,
2105
+ temb: Optional[torch.Tensor] = None,
2106
+ *args,
2107
+ **kwargs,
2108
+ ) -> torch.Tensor:
2109
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
2110
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
2111
+ deprecate("scale", "1.0.0", deprecation_message)
2112
+
2113
+ residual = hidden_states
2114
+
2115
+ if attn.spatial_norm is not None:
2116
+ hidden_states = attn.spatial_norm(hidden_states, temb)
2117
+
2118
+ input_ndim = hidden_states.ndim
2119
+
2120
+ if input_ndim == 4:
2121
+ batch_size, channel, height, width = hidden_states.shape
2122
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
2123
+
2124
+ batch_size, key_tokens, _ = (
2125
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
2126
+ )
2127
+
2128
+ attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
2129
+ if attention_mask is not None:
2130
+ # expand our mask's singleton query_tokens dimension:
2131
+ # [batch*heads, 1, key_tokens] ->
2132
+ # [batch*heads, query_tokens, key_tokens]
2133
+ # so that it can be added as a bias onto the attention scores that xformers computes:
2134
+ # [batch*heads, query_tokens, key_tokens]
2135
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
2136
+ _, query_tokens, _ = hidden_states.shape
2137
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
2138
+
2139
+ if attn.group_norm is not None:
2140
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
2141
+
2142
+ query = attn.to_q(hidden_states)
2143
+
2144
+ if encoder_hidden_states is None:
2145
+ encoder_hidden_states = hidden_states
2146
+ elif attn.norm_cross:
2147
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
2148
+
2149
+ key = attn.to_k(encoder_hidden_states)
2150
+ value = attn.to_v(encoder_hidden_states)
2151
+
2152
+ query = attn.head_to_batch_dim(query).contiguous()
2153
+ key = attn.head_to_batch_dim(key).contiguous()
2154
+ value = attn.head_to_batch_dim(value).contiguous()
2155
+
2156
+ hidden_states = xformers.ops.memory_efficient_attention(
2157
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
2158
+ )
2159
+ hidden_states = hidden_states.to(query.dtype)
2160
+ hidden_states = attn.batch_to_head_dim(hidden_states)
2161
+
2162
+ # linear proj
2163
+ hidden_states = attn.to_out[0](hidden_states)
2164
+ # dropout
2165
+ hidden_states = attn.to_out[1](hidden_states)
2166
+
2167
+ if input_ndim == 4:
2168
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
2169
+
2170
+ if attn.residual_connection:
2171
+ hidden_states = hidden_states + residual
2172
+
2173
+ hidden_states = hidden_states / attn.rescale_output_factor
2174
+
2175
+ return hidden_states
2176
+
2177
+
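The comment inside `XFormersAttnProcessor` above notes that xFormers does not broadcast the singleton query dimension of the bias, hence the explicit `expand`. A tiny shape-only illustration (no xFormers call, sizes made up):

import torch

batch_times_heads, query_tokens, key_tokens = 4, 7, 9

# bias with a singleton query dimension, e.g. as produced by prepare_attention_mask
attention_mask = torch.zeros(batch_times_heads, 1, key_tokens)
attention_mask[..., -2:] = float("-inf")  # e.g. mask out the last two keys

# expand to the [batch*heads, query_tokens, key_tokens] layout xFormers expects
attention_mask = attention_mask.expand(-1, query_tokens, -1)
assert attention_mask.shape == (batch_times_heads, query_tokens, key_tokens)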
2178
+ class AttnProcessorNPU:
2179
+ r"""
2180
+ Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
2181
+ fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
2182
+ not significant.
2183
+
2184
+ """
2185
+
2186
+ def __init__(self):
2187
+ if not is_torch_npu_available():
2188
+ raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")
2189
+
2190
+ def __call__(
2191
+ self,
2192
+ attn: Attention,
2193
+ hidden_states: torch.Tensor,
2194
+ encoder_hidden_states: Optional[torch.Tensor] = None,
2195
+ attention_mask: Optional[torch.Tensor] = None,
2196
+ temb: Optional[torch.Tensor] = None,
2197
+ *args,
2198
+ **kwargs,
2199
+ ) -> torch.Tensor:
2200
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
2201
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
2202
+ deprecate("scale", "1.0.0", deprecation_message)
2203
+
2204
+ residual = hidden_states
2205
+ if attn.spatial_norm is not None:
2206
+ hidden_states = attn.spatial_norm(hidden_states, temb)
2207
+
2208
+ input_ndim = hidden_states.ndim
2209
+
2210
+ if input_ndim == 4:
2211
+ batch_size, channel, height, width = hidden_states.shape
2212
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
2213
+
2214
+ batch_size, sequence_length, _ = (
2215
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
2216
+ )
2217
+
2218
+ if attention_mask is not None:
2219
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
2220
+ # scaled_dot_product_attention expects attention_mask shape to be
2221
+ # (batch, heads, source_length, target_length)
2222
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
2223
+
2224
+ if attn.group_norm is not None:
2225
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
2226
+
2227
+ query = attn.to_q(hidden_states)
2228
+
2229
+ if encoder_hidden_states is None:
2230
+ encoder_hidden_states = hidden_states
2231
+ elif attn.norm_cross:
2232
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
2233
+
2234
+ key = attn.to_k(encoder_hidden_states)
2235
+ value = attn.to_v(encoder_hidden_states)
2236
+
2237
+ inner_dim = key.shape[-1]
2238
+ head_dim = inner_dim // attn.heads
2239
+
2240
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2241
+
2242
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2243
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2244
+
2245
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
2246
+ if query.dtype in (torch.float16, torch.bfloat16):
2247
+ hidden_states = torch_npu.npu_fusion_attention(
2248
+ query,
2249
+ key,
2250
+ value,
2251
+ attn.heads,
2252
+ input_layout="BNSD",
2253
+ pse=None,
2254
+ atten_mask=attention_mask,
2255
+ scale=1.0 / math.sqrt(query.shape[-1]),
2256
+ pre_tockens=65536,
2257
+ next_tockens=65536,
2258
+ keep_prob=1.0,
2259
+ sync=False,
2260
+ inner_precise=0,
2261
+ )[0]
2262
+ else:
2263
+ # TODO: add support for attn.scale when we move to Torch 2.1
2264
+ hidden_states = F.scaled_dot_product_attention(
2265
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
2266
+ )
2267
+
2268
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2269
+ hidden_states = hidden_states.to(query.dtype)
2270
+
2271
+ # linear proj
2272
+ hidden_states = attn.to_out[0](hidden_states)
2273
+ # dropout
2274
+ hidden_states = attn.to_out[1](hidden_states)
2275
+
2276
+ if input_ndim == 4:
2277
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
2278
+
2279
+ if attn.residual_connection:
2280
+ hidden_states = hidden_states + residual
2281
+
2282
+ hidden_states = hidden_states / attn.rescale_output_factor
2283
+
2284
+ return hidden_states
2285
+
2286
+
2287
+ class AttnProcessor2_0:
2288
+ r"""
2289
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
2290
+ """
2291
+
2292
+ def __init__(self):
2293
+ if not hasattr(F, "scaled_dot_product_attention"):
2294
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
2295
+
2296
+ def __call__(
2297
+ self,
2298
+ attn: Attention,
2299
+ hidden_states: torch.Tensor,
2300
+ encoder_hidden_states: Optional[torch.Tensor] = None,
2301
+ attention_mask: Optional[torch.Tensor] = None,
2302
+ temb: Optional[torch.Tensor] = None,
2303
+ *args,
2304
+ **kwargs,
2305
+ ) -> torch.Tensor:
2306
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
2307
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
2308
+ deprecate("scale", "1.0.0", deprecation_message)
2309
+
2310
+ residual = hidden_states
2311
+ if attn.spatial_norm is not None:
2312
+ hidden_states = attn.spatial_norm(hidden_states, temb)
2313
+
2314
+ input_ndim = hidden_states.ndim
2315
+
2316
+ if input_ndim == 4:
2317
+ batch_size, channel, height, width = hidden_states.shape
2318
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
2319
+
2320
+ batch_size, sequence_length, _ = (
2321
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
2322
+ )
2323
+
2324
+ if attention_mask is not None:
2325
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
2326
+ # scaled_dot_product_attention expects attention_mask shape to be
2327
+ # (batch, heads, source_length, target_length)
2328
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
2329
+
2330
+ if attn.group_norm is not None:
2331
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
2332
+
2333
+ query = attn.to_q(hidden_states)
2334
+
2335
+ if encoder_hidden_states is None:
2336
+ encoder_hidden_states = hidden_states
2337
+ elif attn.norm_cross:
2338
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
2339
+
2340
+ key = attn.to_k(encoder_hidden_states)
2341
+ value = attn.to_v(encoder_hidden_states)
2342
+
2343
+ inner_dim = key.shape[-1]
2344
+ head_dim = inner_dim // attn.heads
2345
+
2346
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2347
+
2348
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2349
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2350
+
2351
+ if attn.norm_q is not None:
2352
+ query = attn.norm_q(query)
2353
+ if attn.norm_k is not None:
2354
+ key = attn.norm_k(key)
2355
+
2356
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
2357
+ # TODO: add support for attn.scale when we move to Torch 2.1
2358
+ hidden_states = F.scaled_dot_product_attention(
2359
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
2360
+ )
2361
+
2362
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2363
+ hidden_states = hidden_states.to(query.dtype)
2364
+
2365
+ # linear proj
2366
+ hidden_states = attn.to_out[0](hidden_states)
2367
+ # dropout
2368
+ hidden_states = attn.to_out[1](hidden_states)
2369
+
2370
+ if input_ndim == 4:
2371
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
2372
+
2373
+ if attn.residual_connection:
2374
+ hidden_states = hidden_states + residual
2375
+
2376
+ hidden_states = hidden_states / attn.rescale_output_factor
2377
+
2378
+ return hidden_states
2379
+
2380
+
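A compact sketch of the reshape round trip `AttnProcessor2_0` performs around `F.scaled_dot_product_attention`, and of the `(batch, heads, query_len, key_len)` mask layout it expects; the sizes and the all-True boolean mask are illustrative, not taken from any pipeline:

import torch
import torch.nn.functional as F

batch_size, seq_len, heads, head_dim = 2, 10, 4, 16
hidden_states = torch.randn(batch_size, seq_len, heads * head_dim)

# (batch, seq, heads*head_dim) -> (batch, heads, seq, head_dim)
query = key = value = hidden_states.view(batch_size, -1, heads, head_dim).transpose(1, 2)

# sdpa wants the mask as (batch, heads, query_len, key_len); True means "keep"
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
attention_mask = attention_mask.view(batch_size, 1, 1, seq_len).expand(-1, heads, seq_len, -1)

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)

# (batch, heads, seq, head_dim) -> back to (batch, seq, heads*head_dim)
out = out.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
assert out.shape == hidden_states.shape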
2381
+ class StableAudioAttnProcessor2_0:
2382
+ r"""
2383
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
2384
+ used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA.
2385
+ """
2386
+
2387
+ def __init__(self):
2388
+ if not hasattr(F, "scaled_dot_product_attention"):
2389
+ raise ImportError(
2390
+ "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
2391
+ )
2392
+
2393
+ def apply_partial_rotary_emb(
2394
+ self,
2395
+ x: torch.Tensor,
2396
+ freqs_cis: Tuple[torch.Tensor],
2397
+ ) -> torch.Tensor:
2398
+ from .embeddings import apply_rotary_emb
2399
+
2400
+ rot_dim = freqs_cis[0].shape[-1]
2401
+ x_to_rotate, x_unrotated = x[..., :rot_dim], x[..., rot_dim:]
2402
+
2403
+ x_rotated = apply_rotary_emb(x_to_rotate, freqs_cis, use_real=True, use_real_unbind_dim=-2)
2404
+
2405
+ out = torch.cat((x_rotated, x_unrotated), dim=-1)
2406
+ return out
2407
+
2408
+ def __call__(
2409
+ self,
2410
+ attn: Attention,
2411
+ hidden_states: torch.Tensor,
2412
+ encoder_hidden_states: Optional[torch.Tensor] = None,
2413
+ attention_mask: Optional[torch.Tensor] = None,
2414
+ rotary_emb: Optional[torch.Tensor] = None,
2415
+ ) -> torch.Tensor:
2416
+ from .embeddings import apply_rotary_emb
2417
+
2418
+ residual = hidden_states
2419
+
2420
+ input_ndim = hidden_states.ndim
2421
+
2422
+ if input_ndim == 4:
2423
+ batch_size, channel, height, width = hidden_states.shape
2424
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
2425
+
2426
+ batch_size, sequence_length, _ = (
2427
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
2428
+ )
2429
+
2430
+ if attention_mask is not None:
2431
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
2432
+ # scaled_dot_product_attention expects attention_mask shape to be
2433
+ # (batch, heads, source_length, target_length)
2434
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
2435
+
2436
+ query = attn.to_q(hidden_states)
2437
+
2438
+ if encoder_hidden_states is None:
2439
+ encoder_hidden_states = hidden_states
2440
+ elif attn.norm_cross:
2441
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
2442
+
2443
+ key = attn.to_k(encoder_hidden_states)
2444
+ value = attn.to_v(encoder_hidden_states)
2445
+
2446
+ head_dim = query.shape[-1] // attn.heads
2447
+ kv_heads = key.shape[-1] // head_dim
2448
+
2449
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2450
+
2451
+ key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
2452
+ value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
2453
+
2454
+ if kv_heads != attn.heads:
2455
+ # if GQA or MQA, repeat the key/value heads to reach the number of query heads.
2456
+ heads_per_kv_head = attn.heads // kv_heads
2457
+ key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
2458
+ value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
2459
+
2460
+ if attn.norm_q is not None:
2461
+ query = attn.norm_q(query)
2462
+ if attn.norm_k is not None:
2463
+ key = attn.norm_k(key)
2464
+
2465
+ # Apply RoPE if needed
2466
+ if rotary_emb is not None:
2467
+ query_dtype = query.dtype
2468
+ key_dtype = key.dtype
2469
+ query = query.to(torch.float32)
2470
+ key = key.to(torch.float32)
2471
+
2472
+ rot_dim = rotary_emb[0].shape[-1]
2473
+ query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:]
2474
+ query_rotated = apply_rotary_emb(query_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2)
2475
+
2476
+ query = torch.cat((query_rotated, query_unrotated), dim=-1)
2477
+
2478
+ if not attn.is_cross_attention:
2479
+ key_to_rotate, key_unrotated = key[..., :rot_dim], key[..., rot_dim:]
2480
+ key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2)
2481
+
2482
+ key = torch.cat((key_rotated, key_unrotated), dim=-1)
2483
+
2484
+ query = query.to(query_dtype)
2485
+ key = key.to(key_dtype)
2486
+
2487
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
2488
+ # TODO: add support for attn.scale when we move to Torch 2.1
2489
+ hidden_states = F.scaled_dot_product_attention(
2490
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
2491
+ )
2492
+
2493
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2494
+ hidden_states = hidden_states.to(query.dtype)
2495
+
2496
+ # linear proj
2497
+ hidden_states = attn.to_out[0](hidden_states)
2498
+ # dropout
2499
+ hidden_states = attn.to_out[1](hidden_states)
2500
+
2501
+ if input_ndim == 4:
2502
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
2503
+
2504
+ if attn.residual_connection:
2505
+ hidden_states = hidden_states + residual
2506
+
2507
+ hidden_states = hidden_states / attn.rescale_output_factor
2508
+
2509
+ return hidden_states
2510
+
2511
+
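When the key/value projections in `StableAudioAttnProcessor2_0` have fewer heads than the query (GQA or MQA), each kv head is repeated so that every group of query heads gets its own copy. A minimal illustration with made-up sizes:

import torch

batch, seq, head_dim = 2, 6, 8
heads, kv_heads = 8, 2  # grouped-query attention: 4 query heads share each kv head

key = torch.randn(batch, kv_heads, seq, head_dim)
heads_per_kv_head = heads // kv_heads
key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)

assert key.shape == (batch, heads, seq, head_dim)
# heads 0-3 see the first kv head, heads 4-7 the second
assert torch.equal(key[:, 0], key[:, 3])
assert torch.equal(key[:, 4], key[:, 7])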
2512
+ class HunyuanAttnProcessor2_0:
2513
+ r"""
2514
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
2515
+ used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector.
2516
+ """
2517
+
2518
+ def __init__(self):
2519
+ if not hasattr(F, "scaled_dot_product_attention"):
2520
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
2521
+
2522
+ def __call__(
2523
+ self,
2524
+ attn: Attention,
2525
+ hidden_states: torch.Tensor,
2526
+ encoder_hidden_states: Optional[torch.Tensor] = None,
2527
+ attention_mask: Optional[torch.Tensor] = None,
2528
+ temb: Optional[torch.Tensor] = None,
2529
+ image_rotary_emb: Optional[torch.Tensor] = None,
2530
+ ) -> torch.Tensor:
2531
+ from .embeddings import apply_rotary_emb
2532
+
2533
+ residual = hidden_states
2534
+ if attn.spatial_norm is not None:
2535
+ hidden_states = attn.spatial_norm(hidden_states, temb)
2536
+
2537
+ input_ndim = hidden_states.ndim
1440
2538
 
1441
2539
  if input_ndim == 4:
1442
2540
  batch_size, channel, height, width = hidden_states.shape
@@ -1473,28 +2571,22 @@ class AttnProcessorNPU:
1473
2571
  key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1474
2572
  value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1475
2573
 
2574
+ if attn.norm_q is not None:
2575
+ query = attn.norm_q(query)
2576
+ if attn.norm_k is not None:
2577
+ key = attn.norm_k(key)
2578
+
2579
+ # Apply RoPE if needed
2580
+ if image_rotary_emb is not None:
2581
+ query = apply_rotary_emb(query, image_rotary_emb)
2582
+ if not attn.is_cross_attention:
2583
+ key = apply_rotary_emb(key, image_rotary_emb)
2584
+
1476
2585
  # the output of sdp = (batch, num_heads, seq_len, head_dim)
1477
- if query.dtype in (torch.float16, torch.bfloat16):
1478
- hidden_states = torch_npu.npu_fusion_attention(
1479
- query,
1480
- key,
1481
- value,
1482
- attn.heads,
1483
- input_layout="BNSD",
1484
- pse=None,
1485
- atten_mask=attention_mask,
1486
- scale=1.0 / math.sqrt(query.shape[-1]),
1487
- pre_tockens=65536,
1488
- next_tockens=65536,
1489
- keep_prob=1.0,
1490
- sync=False,
1491
- inner_precise=0,
1492
- )[0]
1493
- else:
1494
- # TODO: add support for attn.scale when we move to Torch 2.1
1495
- hidden_states = F.scaled_dot_product_attention(
1496
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1497
- )
2586
+ # TODO: add support for attn.scale when we move to Torch 2.1
2587
+ hidden_states = F.scaled_dot_product_attention(
2588
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
2589
+ )
1498
2590
 
1499
2591
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1500
2592
  hidden_states = hidden_states.to(query.dtype)
@@ -1515,14 +2607,18 @@ class AttnProcessorNPU:
1515
2607
  return hidden_states
1516
2608
 
1517
2609
 
1518
- class AttnProcessor2_0:
2610
+ class FusedHunyuanAttnProcessor2_0:
1519
2611
  r"""
1520
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
2612
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused
2613
+ projection layers. This is used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on
2614
+ query and key vector.
1521
2615
  """
1522
2616
 
1523
2617
  def __init__(self):
1524
2618
  if not hasattr(F, "scaled_dot_product_attention"):
1525
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
2619
+ raise ImportError(
2620
+ "FusedHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
2621
+ )
1526
2622
 
1527
2623
  def __call__(
1528
2624
  self,
@@ -1531,12 +2627,9 @@ class AttnProcessor2_0:
1531
2627
  encoder_hidden_states: Optional[torch.Tensor] = None,
1532
2628
  attention_mask: Optional[torch.Tensor] = None,
1533
2629
  temb: Optional[torch.Tensor] = None,
1534
- *args,
1535
- **kwargs,
2630
+ image_rotary_emb: Optional[torch.Tensor] = None,
1536
2631
  ) -> torch.Tensor:
1537
- if len(args) > 0 or kwargs.get("scale", None) is not None:
1538
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1539
- deprecate("scale", "1.0.0", deprecation_message)
2632
+ from .embeddings import apply_rotary_emb
1540
2633
 
1541
2634
  residual = hidden_states
1542
2635
  if attn.spatial_norm is not None:
@@ -1561,24 +2654,37 @@ class AttnProcessor2_0:
1561
2654
  if attn.group_norm is not None:
1562
2655
  hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1563
2656
 
1564
- query = attn.to_q(hidden_states)
1565
-
1566
2657
  if encoder_hidden_states is None:
1567
- encoder_hidden_states = hidden_states
1568
- elif attn.norm_cross:
1569
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
2658
+ qkv = attn.to_qkv(hidden_states)
2659
+ split_size = qkv.shape[-1] // 3
2660
+ query, key, value = torch.split(qkv, split_size, dim=-1)
2661
+ else:
2662
+ if attn.norm_cross:
2663
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
2664
+ query = attn.to_q(hidden_states)
1570
2665
 
1571
- key = attn.to_k(encoder_hidden_states)
1572
- value = attn.to_v(encoder_hidden_states)
2666
+ kv = attn.to_kv(encoder_hidden_states)
2667
+ split_size = kv.shape[-1] // 2
2668
+ key, value = torch.split(kv, split_size, dim=-1)
1573
2669
 
1574
2670
  inner_dim = key.shape[-1]
1575
2671
  head_dim = inner_dim // attn.heads
1576
2672
 
1577
2673
  query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1578
-
1579
2674
  key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1580
2675
  value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1581
2676
 
2677
+ if attn.norm_q is not None:
2678
+ query = attn.norm_q(query)
2679
+ if attn.norm_k is not None:
2680
+ key = attn.norm_k(key)
2681
+
2682
+ # Apply RoPE if needed
2683
+ if image_rotary_emb is not None:
2684
+ query = apply_rotary_emb(query, image_rotary_emb)
2685
+ if not attn.is_cross_attention:
2686
+ key = apply_rotary_emb(key, image_rotary_emb)
2687
+
1582
2688
  # the output of sdp = (batch, num_heads, seq_len, head_dim)
1583
2689
  # TODO: add support for attn.scale when we move to Torch 2.1
1584
2690
  hidden_states = F.scaled_dot_product_attention(
@@ -1604,15 +2710,18 @@ class AttnProcessor2_0:
1604
2710
  return hidden_states
1605
2711
 
1606
2712
 
1607
- class HunyuanAttnProcessor2_0:
2713
+ class PAGHunyuanAttnProcessor2_0:
1608
2714
  r"""
1609
2715
  Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
1610
- used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on query and key vector.
2716
+ used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This
2717
+ variant of the processor employs [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377).
1611
2718
  """
1612
2719
 
1613
2720
  def __init__(self):
1614
2721
  if not hasattr(F, "scaled_dot_product_attention"):
1615
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
2722
+ raise ImportError(
2723
+ "PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
2724
+ )
1616
2725
 
1617
2726
  def __call__(
1618
2727
  self,
@@ -1635,8 +2744,12 @@ class HunyuanAttnProcessor2_0:
1635
2744
  batch_size, channel, height, width = hidden_states.shape
1636
2745
  hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1637
2746
 
2747
+ # chunk
2748
+ hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
2749
+
2750
+ # 1. Original Path
1638
2751
  batch_size, sequence_length, _ = (
1639
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
2752
+ hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1640
2753
  )
1641
2754
 
1642
2755
  if attention_mask is not None:
@@ -1646,12 +2759,12 @@ class HunyuanAttnProcessor2_0:
1646
2759
  attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1647
2760
 
1648
2761
  if attn.group_norm is not None:
1649
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
2762
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
1650
2763
 
1651
- query = attn.to_q(hidden_states)
2764
+ query = attn.to_q(hidden_states_org)
1652
2765
 
1653
2766
  if encoder_hidden_states is None:
1654
- encoder_hidden_states = hidden_states
2767
+ encoder_hidden_states = hidden_states_org
1655
2768
  elif attn.norm_cross:
1656
2769
  encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1657
2770
 
@@ -1679,25 +2792,263 @@ class HunyuanAttnProcessor2_0:
1679
2792
 
1680
2793
  # the output of sdp = (batch, num_heads, seq_len, head_dim)
1681
2794
  # TODO: add support for attn.scale when we move to Torch 2.1
1682
- hidden_states = F.scaled_dot_product_attention(
2795
+ hidden_states_org = F.scaled_dot_product_attention(
1683
2796
  query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1684
2797
  )
1685
2798
 
1686
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1687
- hidden_states = hidden_states.to(query.dtype)
2799
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2800
+ hidden_states_org = hidden_states_org.to(query.dtype)
1688
2801
 
1689
2802
  # linear proj
1690
- hidden_states = attn.to_out[0](hidden_states)
2803
+ hidden_states_org = attn.to_out[0](hidden_states_org)
1691
2804
  # dropout
1692
- hidden_states = attn.to_out[1](hidden_states)
2805
+ hidden_states_org = attn.to_out[1](hidden_states_org)
2806
+
2807
+ if input_ndim == 4:
2808
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
2809
+
2810
+ # 2. Perturbed Path
2811
+ if attn.group_norm is not None:
2812
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
2813
+
2814
+ hidden_states_ptb = attn.to_v(hidden_states_ptb)
2815
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
2816
+
2817
+ # linear proj
2818
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
2819
+ # dropout
2820
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
2821
+
2822
+ if input_ndim == 4:
2823
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
2824
+
2825
+ # cat
2826
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
2827
+
2828
+ if attn.residual_connection:
2829
+ hidden_states = hidden_states + residual
2830
+
2831
+ hidden_states = hidden_states / attn.rescale_output_factor
2832
+
2833
+ return hidden_states
2834
+
2835
+
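The perturbed branch above never calls `F.scaled_dot_product_attention`: replacing the attention map with the identity matrix simply returns the value projection, which is why the processor computes `attn.to_v(hidden_states_ptb)` directly. A small check of that equivalence (illustrative shapes):

import torch

batch, heads, seq, head_dim = 2, 4, 6, 8
value = torch.randn(batch, heads, seq, head_dim)

# identity attention map: every token attends only to itself
identity = torch.eye(seq).expand(batch, heads, seq, seq)
perturbed = identity @ value

assert torch.allclose(perturbed, value, atol=1e-6)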
2836
+ class PAGCFGHunyuanAttnProcessor2_0:
2837
+ r"""
2838
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
2839
+ used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This
2840
+ variant of the processor employs [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377).
2841
+ """
2842
+
2843
+ def __init__(self):
2844
+ if not hasattr(F, "scaled_dot_product_attention"):
2845
+ raise ImportError(
2846
+ "PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
2847
+ )
2848
+
2849
+ def __call__(
2850
+ self,
2851
+ attn: Attention,
2852
+ hidden_states: torch.Tensor,
2853
+ encoder_hidden_states: Optional[torch.Tensor] = None,
2854
+ attention_mask: Optional[torch.Tensor] = None,
2855
+ temb: Optional[torch.Tensor] = None,
2856
+ image_rotary_emb: Optional[torch.Tensor] = None,
2857
+ ) -> torch.Tensor:
2858
+ from .embeddings import apply_rotary_emb
2859
+
2860
+ residual = hidden_states
2861
+ if attn.spatial_norm is not None:
2862
+ hidden_states = attn.spatial_norm(hidden_states, temb)
2863
+
2864
+ input_ndim = hidden_states.ndim
2865
+
2866
+ if input_ndim == 4:
2867
+ batch_size, channel, height, width = hidden_states.shape
2868
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
2869
+
2870
+ # chunk
2871
+ hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
2872
+ hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
2873
+
2874
+ # 1. Original Path
2875
+ batch_size, sequence_length, _ = (
2876
+ hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
2877
+ )
2878
+
2879
+ if attention_mask is not None:
2880
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
2881
+ # scaled_dot_product_attention expects attention_mask shape to be
2882
+ # (batch, heads, source_length, target_length)
2883
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
2884
+
2885
+ if attn.group_norm is not None:
2886
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
2887
+
2888
+ query = attn.to_q(hidden_states_org)
2889
+
2890
+ if encoder_hidden_states is None:
2891
+ encoder_hidden_states = hidden_states_org
2892
+ elif attn.norm_cross:
2893
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
2894
+
2895
+ key = attn.to_k(encoder_hidden_states)
2896
+ value = attn.to_v(encoder_hidden_states)
2897
+
2898
+ inner_dim = key.shape[-1]
2899
+ head_dim = inner_dim // attn.heads
2900
+
2901
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2902
+
2903
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2904
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2905
+
2906
+ if attn.norm_q is not None:
2907
+ query = attn.norm_q(query)
2908
+ if attn.norm_k is not None:
2909
+ key = attn.norm_k(key)
2910
+
2911
+ # Apply RoPE if needed
2912
+ if image_rotary_emb is not None:
2913
+ query = apply_rotary_emb(query, image_rotary_emb)
2914
+ if not attn.is_cross_attention:
2915
+ key = apply_rotary_emb(key, image_rotary_emb)
2916
+
2917
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
2918
+ # TODO: add support for attn.scale when we move to Torch 2.1
2919
+ hidden_states_org = F.scaled_dot_product_attention(
2920
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
2921
+ )
2922
+
2923
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2924
+ hidden_states_org = hidden_states_org.to(query.dtype)
2925
+
2926
+ # linear proj
2927
+ hidden_states_org = attn.to_out[0](hidden_states_org)
2928
+ # dropout
2929
+ hidden_states_org = attn.to_out[1](hidden_states_org)
2930
+
2931
+ if input_ndim == 4:
2932
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
2933
+
2934
+ # 2. Perturbed Path
2935
+ if attn.group_norm is not None:
2936
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
2937
+
2938
+ hidden_states_ptb = attn.to_v(hidden_states_ptb)
2939
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
2940
+
2941
+ # linear proj
2942
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
2943
+ # dropout
2944
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
2945
+
2946
+ if input_ndim == 4:
2947
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
2948
+
2949
+ # cat
2950
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
2951
+
2952
+ if attn.residual_connection:
2953
+ hidden_states = hidden_states + residual
2954
+
2955
+ hidden_states = hidden_states / attn.rescale_output_factor
2956
+
2957
+ return hidden_states
2958
+
2959
+
2960
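This CFG variant expects the batch to be laid out as [unconditional, conditional, perturbed]; the first two chunks go through regular attention while the third gets the identity-attention treatment. A sketch of how the three predictions are typically recombined downstream (the exact combination lives in the PAG pipelines, not in this file, so treat the formula as an assumption):

import torch

def combine_cfg_pag(noise_pred: torch.Tensor, guidance_scale: float, pag_scale: float) -> torch.Tensor:
    # noise_pred is the model output for the concatenated [uncond, cond, perturbed] batch
    uncond, cond, perturbed = noise_pred.chunk(3)
    return uncond + guidance_scale * (cond - uncond) + pag_scale * (cond - perturbed)

model_out = torch.randn(6, 4, 64, 64)  # tripled batch of latent-shaped predictions
guided = combine_cfg_pag(model_out, guidance_scale=5.0, pag_scale=3.0)
assert guided.shape == (2, 4, 64, 64)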
+ class LuminaAttnProcessor2_0:
2961
+ r"""
2962
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
2963
+ used in the LuminaNextDiT model. It applies a normalization layer and rotary embedding on the query and key vectors.
2964
+ """
2965
+
2966
+ def __init__(self):
2967
+ if not hasattr(F, "scaled_dot_product_attention"):
2968
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
2969
+
2970
+ def __call__(
2971
+ self,
2972
+ attn: Attention,
2973
+ hidden_states: torch.Tensor,
2974
+ encoder_hidden_states: torch.Tensor,
2975
+ attention_mask: Optional[torch.Tensor] = None,
2976
+ query_rotary_emb: Optional[torch.Tensor] = None,
2977
+ key_rotary_emb: Optional[torch.Tensor] = None,
2978
+ base_sequence_length: Optional[int] = None,
2979
+ ) -> torch.Tensor:
2980
+ from .embeddings import apply_rotary_emb
2981
+
2982
+ input_ndim = hidden_states.ndim
2983
+
2984
+ if input_ndim == 4:
2985
+ batch_size, channel, height, width = hidden_states.shape
2986
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
2987
+
2988
+ batch_size, sequence_length, _ = hidden_states.shape
2989
+
2990
+ # Get Query-Key-Value Pair
2991
+ query = attn.to_q(hidden_states)
2992
+ key = attn.to_k(encoder_hidden_states)
2993
+ value = attn.to_v(encoder_hidden_states)
2994
+
2995
+ query_dim = query.shape[-1]
2996
+ inner_dim = key.shape[-1]
2997
+ head_dim = query_dim // attn.heads
2998
+ dtype = query.dtype
2999
+
3000
+ # Get key-value heads
3001
+ kv_heads = inner_dim // head_dim
3002
+
3003
+ # Apply Query-Key Norm if needed
3004
+ if attn.norm_q is not None:
3005
+ query = attn.norm_q(query)
3006
+ if attn.norm_k is not None:
3007
+ key = attn.norm_k(key)
3008
+
3009
+ query = query.view(batch_size, -1, attn.heads, head_dim)
3010
+
3011
+ key = key.view(batch_size, -1, kv_heads, head_dim)
3012
+ value = value.view(batch_size, -1, kv_heads, head_dim)
3013
+
3014
+ # Apply RoPE if needed
3015
+ if query_rotary_emb is not None:
3016
+ query = apply_rotary_emb(query, query_rotary_emb, use_real=False)
3017
+ if key_rotary_emb is not None:
3018
+ key = apply_rotary_emb(key, key_rotary_emb, use_real=False)
3019
+
3020
+ query, key = query.to(dtype), key.to(dtype)
3021
+
3022
+ # Apply proportional attention if true
3023
+ if key_rotary_emb is None:
3024
+ softmax_scale = None
3025
+ else:
3026
+ if base_sequence_length is not None:
3027
+ softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
3028
+ else:
3029
+ softmax_scale = attn.scale
1693
3030
 
1694
- if input_ndim == 4:
1695
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
3031
+ # perform Grouped-query Attention (GQA)
3032
+ n_rep = attn.heads // kv_heads
3033
+ if n_rep >= 1:
3034
+ key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
3035
+ value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
1696
3036
 
1697
- if attn.residual_connection:
1698
- hidden_states = hidden_states + residual
3037
+ # scaled_dot_product_attention expects attention_mask shape to be
3038
+ # (batch, heads, source_length, target_length)
3039
+ attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
3040
+ attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)
1699
3041
 
1700
- hidden_states = hidden_states / attn.rescale_output_factor
3042
+ query = query.transpose(1, 2)
3043
+ key = key.transpose(1, 2)
3044
+ value = value.transpose(1, 2)
3045
+
3046
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
3047
+ # TODO: add support for attn.scale when we move to Torch 2.1
3048
+ hidden_states = F.scaled_dot_product_attention(
3049
+ query, key, value, attn_mask=attention_mask, scale=softmax_scale
3050
+ )
3051
+ hidden_states = hidden_states.transpose(1, 2).to(dtype)
1701
3052
 
1702
3053
  return hidden_states
1703
3054
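The Lumina processor supports grouped-query attention: when `attn.heads` exceeds the number of key/value heads, each key/value head is repeated `n_rep` times so head counts match before `scaled_dot_product_attention`; and when `base_sequence_length` is given, the softmax scale is rescaled by sqrt(log_base(sequence_length)). A small shape check of the repeat step (illustration only; dimensions are made up):

import torch

batch, seq, kv_heads, head_dim, n_rep = 2, 16, 2, 8, 4  # the query would have kv_heads * n_rep heads
key = torch.randn(batch, seq, kv_heads, head_dim)
key_rep = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
assert key_rep.shape == (batch, seq, kv_heads * n_rep, head_dim)
# every group of n_rep query heads sees the same original key/value head
assert torch.equal(key_rep[:, :, 0], key_rep[:, :, n_rep - 1])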
 
@@ -1778,6 +3129,11 @@ class FusedAttnProcessor2_0:
1778
3129
  key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1779
3130
  value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1780
3131
 
3132
+ if attn.norm_q is not None:
3133
+ query = attn.norm_q(query)
3134
+ if attn.norm_k is not None:
3135
+ key = attn.norm_k(key)
3136
+
1781
3137
  # the output of sdp = (batch, num_heads, seq_len, head_dim)
1782
3138
  # TODO: add support for attn.scale when we move to Torch 2.1
1783
3139
  hidden_states = F.scaled_dot_product_attention(
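FusedAttnProcessor2_0 now applies the optional query/key normalization before the attention call. A minimal sketch of what such a hook does (assumption: `norm_q`/`norm_k` are typically LayerNorm- or RMSNorm-style modules over `head_dim`, configured on the `Attention` block):

import torch
import torch.nn as nn

head_dim = 64
norm_q = nn.LayerNorm(head_dim)           # stand-in for attn.norm_q
query = torch.randn(2, 8, 128, head_dim)  # (batch, heads, seq_len, head_dim)
query = norm_q(query)                     # normalizes each head vector before the dot product
assert query.shape == (2, 8, 128, head_dim)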
@@ -2088,7 +3444,7 @@ class SlicedAttnProcessor:
2088
3444
  (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
2089
3445
  )
2090
3446
 
2091
- for i in range(batch_size_attention // self.slice_size):
3447
+ for i in range((batch_size_attention - 1) // self.slice_size + 1):
2092
3448
  start_idx = i * self.slice_size
2093
3449
  end_idx = (i + 1) * self.slice_size
2094
3450
 
@@ -2185,7 +3541,7 @@ class SlicedAttnAddedKVProcessor:
2185
3541
  (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
2186
3542
  )
2187
3543
 
2188
- for i in range(batch_size_attention // self.slice_size):
3544
+ for i in range((batch_size_attention - 1) // self.slice_size + 1):
2189
3545
  start_idx = i * self.slice_size
2190
3546
  end_idx = (i + 1) * self.slice_size
2191
3547
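Both sliced-attention loops switch from floor to ceiling division, so a trailing slice smaller than `slice_size` is no longer dropped. A quick illustration of the old versus new bound:

batch_size_attention, slice_size = 10, 4
old_steps = batch_size_attention // slice_size            # 2 -> rows 8 and 9 were never attended
new_steps = (batch_size_attention - 1) // slice_size + 1  # 3 -> ceiling division covers them
covered = set()
for i in range(new_steps):
    covered.update(range(i * slice_size, min((i + 1) * slice_size, batch_size_attention)))
assert (old_steps, new_steps) == (2, 3)
assert covered == set(range(batch_size_attention))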
 
@@ -2241,264 +3597,6 @@ class SpatialNorm(nn.Module):
2241
3597
  return new_f
2242
3598
 
2243
3599
 
2244
- class LoRAAttnProcessor(nn.Module):
2245
- def __init__(
2246
- self,
2247
- hidden_size: int,
2248
- cross_attention_dim: Optional[int] = None,
2249
- rank: int = 4,
2250
- network_alpha: Optional[int] = None,
2251
- **kwargs,
2252
- ):
2253
- deprecation_message = "Using LoRAAttnProcessor is deprecated. Please use the PEFT backend for all things LoRA. You can install PEFT by running `pip install peft`."
2254
- deprecate("LoRAAttnProcessor", "0.30.0", deprecation_message, standard_warn=False)
2255
-
2256
- super().__init__()
2257
-
2258
- self.hidden_size = hidden_size
2259
- self.cross_attention_dim = cross_attention_dim
2260
- self.rank = rank
2261
-
2262
- q_rank = kwargs.pop("q_rank", None)
2263
- q_hidden_size = kwargs.pop("q_hidden_size", None)
2264
- q_rank = q_rank if q_rank is not None else rank
2265
- q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
2266
-
2267
- v_rank = kwargs.pop("v_rank", None)
2268
- v_hidden_size = kwargs.pop("v_hidden_size", None)
2269
- v_rank = v_rank if v_rank is not None else rank
2270
- v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
2271
-
2272
- out_rank = kwargs.pop("out_rank", None)
2273
- out_hidden_size = kwargs.pop("out_hidden_size", None)
2274
- out_rank = out_rank if out_rank is not None else rank
2275
- out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
2276
-
2277
- self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
2278
- self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
2279
- self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
2280
- self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
2281
-
2282
- def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
2283
- self_cls_name = self.__class__.__name__
2284
- deprecate(
2285
- self_cls_name,
2286
- "0.26.0",
2287
- (
2288
- f"Make sure use {self_cls_name[4:]} instead by setting"
2289
- "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
2290
- " `LoraLoaderMixin.load_lora_weights`"
2291
- ),
2292
- )
2293
- attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
2294
- attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
2295
- attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
2296
- attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
2297
-
2298
- attn._modules.pop("processor")
2299
- attn.processor = AttnProcessor()
2300
- return attn.processor(attn, hidden_states, **kwargs)
2301
-
2302
-
2303
- class LoRAAttnProcessor2_0(nn.Module):
2304
- def __init__(
2305
- self,
2306
- hidden_size: int,
2307
- cross_attention_dim: Optional[int] = None,
2308
- rank: int = 4,
2309
- network_alpha: Optional[int] = None,
2310
- **kwargs,
2311
- ):
2312
- deprecation_message = "Using LoRAAttnProcessor is deprecated. Please use the PEFT backend for all things LoRA. You can install PEFT by running `pip install peft`."
2313
- deprecate("LoRAAttnProcessor2_0", "0.30.0", deprecation_message, standard_warn=False)
2314
-
2315
- super().__init__()
2316
- if not hasattr(F, "scaled_dot_product_attention"):
2317
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
2318
-
2319
- self.hidden_size = hidden_size
2320
- self.cross_attention_dim = cross_attention_dim
2321
- self.rank = rank
2322
-
2323
- q_rank = kwargs.pop("q_rank", None)
2324
- q_hidden_size = kwargs.pop("q_hidden_size", None)
2325
- q_rank = q_rank if q_rank is not None else rank
2326
- q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
2327
-
2328
- v_rank = kwargs.pop("v_rank", None)
2329
- v_hidden_size = kwargs.pop("v_hidden_size", None)
2330
- v_rank = v_rank if v_rank is not None else rank
2331
- v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
2332
-
2333
- out_rank = kwargs.pop("out_rank", None)
2334
- out_hidden_size = kwargs.pop("out_hidden_size", None)
2335
- out_rank = out_rank if out_rank is not None else rank
2336
- out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
2337
-
2338
- self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
2339
- self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
2340
- self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
2341
- self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
2342
-
2343
- def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
2344
- self_cls_name = self.__class__.__name__
2345
- deprecate(
2346
- self_cls_name,
2347
- "0.26.0",
2348
- (
2349
- f"Make sure use {self_cls_name[4:]} instead by setting"
2350
- "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
2351
- " `LoraLoaderMixin.load_lora_weights`"
2352
- ),
2353
- )
2354
- attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
2355
- attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
2356
- attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
2357
- attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
2358
-
2359
- attn._modules.pop("processor")
2360
- attn.processor = AttnProcessor2_0()
2361
- return attn.processor(attn, hidden_states, **kwargs)
2362
-
2363
-
2364
- class LoRAXFormersAttnProcessor(nn.Module):
2365
- r"""
2366
- Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
2367
-
2368
- Args:
2369
- hidden_size (`int`, *optional*):
2370
- The hidden size of the attention layer.
2371
- cross_attention_dim (`int`, *optional*):
2372
- The number of channels in the `encoder_hidden_states`.
2373
- rank (`int`, defaults to 4):
2374
- The dimension of the LoRA update matrices.
2375
- attention_op (`Callable`, *optional*, defaults to `None`):
2376
- The base
2377
- [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
2378
- use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
2379
- operator.
2380
- network_alpha (`int`, *optional*):
2381
- Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
2382
- kwargs (`dict`):
2383
- Additional keyword arguments to pass to the `LoRALinearLayer` layers.
2384
- """
2385
-
2386
- def __init__(
2387
- self,
2388
- hidden_size: int,
2389
- cross_attention_dim: int,
2390
- rank: int = 4,
2391
- attention_op: Optional[Callable] = None,
2392
- network_alpha: Optional[int] = None,
2393
- **kwargs,
2394
- ):
2395
- super().__init__()
2396
-
2397
- self.hidden_size = hidden_size
2398
- self.cross_attention_dim = cross_attention_dim
2399
- self.rank = rank
2400
- self.attention_op = attention_op
2401
-
2402
- q_rank = kwargs.pop("q_rank", None)
2403
- q_hidden_size = kwargs.pop("q_hidden_size", None)
2404
- q_rank = q_rank if q_rank is not None else rank
2405
- q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
2406
-
2407
- v_rank = kwargs.pop("v_rank", None)
2408
- v_hidden_size = kwargs.pop("v_hidden_size", None)
2409
- v_rank = v_rank if v_rank is not None else rank
2410
- v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
2411
-
2412
- out_rank = kwargs.pop("out_rank", None)
2413
- out_hidden_size = kwargs.pop("out_hidden_size", None)
2414
- out_rank = out_rank if out_rank is not None else rank
2415
- out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
2416
-
2417
- self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
2418
- self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
2419
- self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
2420
- self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
2421
-
2422
- def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
2423
- self_cls_name = self.__class__.__name__
2424
- deprecate(
2425
- self_cls_name,
2426
- "0.26.0",
2427
- (
2428
- f"Make sure use {self_cls_name[4:]} instead by setting"
2429
- "LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
2430
- " `LoraLoaderMixin.load_lora_weights`"
2431
- ),
2432
- )
2433
- attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
2434
- attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
2435
- attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
2436
- attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
2437
-
2438
- attn._modules.pop("processor")
2439
- attn.processor = XFormersAttnProcessor()
2440
- return attn.processor(attn, hidden_states, **kwargs)
2441
-
2442
-
2443
- class LoRAAttnAddedKVProcessor(nn.Module):
2444
- r"""
2445
- Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
2446
- encoder.
2447
-
2448
- Args:
2449
- hidden_size (`int`, *optional*):
2450
- The hidden size of the attention layer.
2451
- cross_attention_dim (`int`, *optional*, defaults to `None`):
2452
- The number of channels in the `encoder_hidden_states`.
2453
- rank (`int`, defaults to 4):
2454
- The dimension of the LoRA update matrices.
2455
- network_alpha (`int`, *optional*):
2456
- Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
2457
- kwargs (`dict`):
2458
- Additional keyword arguments to pass to the `LoRALinearLayer` layers.
2459
- """
2460
-
2461
- def __init__(
2462
- self,
2463
- hidden_size: int,
2464
- cross_attention_dim: Optional[int] = None,
2465
- rank: int = 4,
2466
- network_alpha: Optional[int] = None,
2467
- ):
2468
- super().__init__()
2469
-
2470
- self.hidden_size = hidden_size
2471
- self.cross_attention_dim = cross_attention_dim
2472
- self.rank = rank
2473
-
2474
- self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
2475
- self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
2476
- self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
2477
- self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
2478
- self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
2479
- self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
2480
-
2481
- def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
2482
- self_cls_name = self.__class__.__name__
2483
- deprecate(
2484
- self_cls_name,
2485
- "0.26.0",
2486
- (
2487
- f"Make sure use {self_cls_name[4:]} instead by setting"
2488
- "LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
2489
- " `LoraLoaderMixin.load_lora_weights`"
2490
- ),
2491
- )
2492
- attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
2493
- attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
2494
- attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
2495
- attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
2496
-
2497
- attn._modules.pop("processor")
2498
- attn.processor = AttnAddedKVProcessor()
2499
- return attn.processor(attn, hidden_states, **kwargs)
2500
-
2501
-
2502
3600
  class IPAdapterAttnProcessor(nn.Module):
2503
3601
  r"""
2504
3602
  Attention processor for Multiple IP-Adapters.
@@ -2927,19 +4025,233 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
2927
4025
  return hidden_states
2928
4026
 
2929
4027
 
2930
- LORA_ATTENTION_PROCESSORS = (
2931
- LoRAAttnProcessor,
2932
- LoRAAttnProcessor2_0,
2933
- LoRAXFormersAttnProcessor,
2934
- LoRAAttnAddedKVProcessor,
2935
- )
4028
+ class PAGIdentitySelfAttnProcessor2_0:
4029
+ r"""
4030
+ Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
4031
+ PAG reference: https://arxiv.org/abs/2403.17377
4032
+ """
4033
+
4034
+ def __init__(self):
4035
+ if not hasattr(F, "scaled_dot_product_attention"):
4036
+ raise ImportError(
4037
+ "PAGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
4038
+ )
4039
+
4040
+ def __call__(
4041
+ self,
4042
+ attn: Attention,
4043
+ hidden_states: torch.FloatTensor,
4044
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
4045
+ attention_mask: Optional[torch.FloatTensor] = None,
4046
+ temb: Optional[torch.FloatTensor] = None,
4047
+ ) -> torch.Tensor:
4048
+ residual = hidden_states
4049
+ if attn.spatial_norm is not None:
4050
+ hidden_states = attn.spatial_norm(hidden_states, temb)
4051
+
4052
+ input_ndim = hidden_states.ndim
4053
+ if input_ndim == 4:
4054
+ batch_size, channel, height, width = hidden_states.shape
4055
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
4056
+
4057
+ # chunk
4058
+ hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
4059
+
4060
+ # original path
4061
+ batch_size, sequence_length, _ = hidden_states_org.shape
4062
+
4063
+ if attention_mask is not None:
4064
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
4065
+ # scaled_dot_product_attention expects attention_mask shape to be
4066
+ # (batch, heads, source_length, target_length)
4067
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
4068
+
4069
+ if attn.group_norm is not None:
4070
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
4071
+
4072
+ query = attn.to_q(hidden_states_org)
4073
+ key = attn.to_k(hidden_states_org)
4074
+ value = attn.to_v(hidden_states_org)
4075
+
4076
+ inner_dim = key.shape[-1]
4077
+ head_dim = inner_dim // attn.heads
4078
+
4079
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
4080
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
4081
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
4082
+
4083
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
4084
+ # TODO: add support for attn.scale when we move to Torch 2.1
4085
+ hidden_states_org = F.scaled_dot_product_attention(
4086
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
4087
+ )
4088
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
4089
+ hidden_states_org = hidden_states_org.to(query.dtype)
4090
+
4091
+ # linear proj
4092
+ hidden_states_org = attn.to_out[0](hidden_states_org)
4093
+ # dropout
4094
+ hidden_states_org = attn.to_out[1](hidden_states_org)
4095
+
4096
+ if input_ndim == 4:
4097
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
4098
+
4099
+ # perturbed path (identity attention)
4100
+ batch_size, sequence_length, _ = hidden_states_ptb.shape
4101
+
4102
+ if attn.group_norm is not None:
4103
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
4104
+
4105
+ hidden_states_ptb = attn.to_v(hidden_states_ptb)
4106
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
4107
+
4108
+ # linear proj
4109
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
4110
+ # dropout
4111
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
4112
+
4113
+ if input_ndim == 4:
4114
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
4115
+
4116
+ # cat
4117
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
4118
+
4119
+ if attn.residual_connection:
4120
+ hidden_states = hidden_states + residual
4121
+
4122
+ hidden_states = hidden_states / attn.rescale_output_factor
4123
+
4124
+ return hidden_states
4125
+
4126
+
4127
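PAGIdentitySelfAttnProcessor2_0 assumes the pipeline has doubled the batch to [original, perturbed]. A sketch of how the two predictions are usually combined afterwards (the combination itself lives in the PAG pipelines, so the formula below is an assumption rather than something defined here):

import torch

def combine_pag(noise_pred: torch.Tensor, pag_scale: float) -> torch.Tensor:
    cond, perturbed = noise_pred.chunk(2)  # matches the chunk(2) layout used above
    return cond + pag_scale * (cond - perturbed)

model_out = torch.randn(4, 4, 64, 64)      # doubled batch
guided = combine_pag(model_out, pag_scale=3.0)
assert guided.shape == (2, 4, 64, 64)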
+ class PAGCFGIdentitySelfAttnProcessor2_0:
4128
+ r"""
4129
+ Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
4130
+ PAG reference: https://arxiv.org/abs/2403.17377
4131
+ """
4132
+
4133
+ def __init__(self):
4134
+ if not hasattr(F, "scaled_dot_product_attention"):
4135
+ raise ImportError(
4136
+ "PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
4137
+ )
4138
+
4139
+ def __call__(
4140
+ self,
4141
+ attn: Attention,
4142
+ hidden_states: torch.FloatTensor,
4143
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
4144
+ attention_mask: Optional[torch.FloatTensor] = None,
4145
+ temb: Optional[torch.FloatTensor] = None,
4146
+ ) -> torch.Tensor:
4147
+ residual = hidden_states
4148
+ if attn.spatial_norm is not None:
4149
+ hidden_states = attn.spatial_norm(hidden_states, temb)
4150
+
4151
+ input_ndim = hidden_states.ndim
4152
+ if input_ndim == 4:
4153
+ batch_size, channel, height, width = hidden_states.shape
4154
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
4155
+
4156
+ # chunk
4157
+ hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
4158
+ hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
4159
+
4160
+ # original path
4161
+ batch_size, sequence_length, _ = hidden_states_org.shape
4162
+
4163
+ if attention_mask is not None:
4164
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
4165
+ # scaled_dot_product_attention expects attention_mask shape to be
4166
+ # (batch, heads, source_length, target_length)
4167
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
4168
+
4169
+ if attn.group_norm is not None:
4170
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
4171
+
4172
+ query = attn.to_q(hidden_states_org)
4173
+ key = attn.to_k(hidden_states_org)
4174
+ value = attn.to_v(hidden_states_org)
4175
+
4176
+ inner_dim = key.shape[-1]
4177
+ head_dim = inner_dim // attn.heads
4178
+
4179
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
4180
+
4181
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
4182
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
4183
+
4184
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
4185
+ # TODO: add support for attn.scale when we move to Torch 2.1
4186
+ hidden_states_org = F.scaled_dot_product_attention(
4187
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
4188
+ )
4189
+
4190
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
4191
+ hidden_states_org = hidden_states_org.to(query.dtype)
4192
+
4193
+ # linear proj
4194
+ hidden_states_org = attn.to_out[0](hidden_states_org)
4195
+ # dropout
4196
+ hidden_states_org = attn.to_out[1](hidden_states_org)
4197
+
4198
+ if input_ndim == 4:
4199
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
4200
+
4201
+ # perturbed path (identity attention)
4202
+ batch_size, sequence_length, _ = hidden_states_ptb.shape
4203
+
4204
+ if attn.group_norm is not None:
4205
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
4206
+
4207
+ value = attn.to_v(hidden_states_ptb)
4208
+ hidden_states_ptb = value
4209
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
4210
+
4211
+ # linear proj
4212
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
4213
+ # dropout
4214
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
4215
+
4216
+ if input_ndim == 4:
4217
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
4218
+
4219
+ # cat
4220
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
4221
+
4222
+ if attn.residual_connection:
4223
+ hidden_states = hidden_states + residual
4224
+
4225
+ hidden_states = hidden_states / attn.rescale_output_factor
4226
+
4227
+ return hidden_states
4228
+
4229
+
4230
+ class LoRAAttnProcessor:
4231
+ def __init__(self):
4232
+ pass
4233
+
4234
+
4235
+ class LoRAAttnProcessor2_0:
4236
+ def __init__(self):
4237
+ pass
4238
+
4239
+
4240
+ class LoRAXFormersAttnProcessor:
4241
+ def __init__(self):
4242
+ pass
4243
+
4244
+
4245
+ class LoRAAttnAddedKVProcessor:
4246
+ def __init__(self):
4247
+ pass
4248
+
2936
4249
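The LoRA processor classes above are now empty placeholders, apparently kept only so existing imports keep working; LoRA itself is handled through the PEFT backend at the pipeline level. A hedged usage sketch (the checkpoint id and LoRA path are placeholders, not taken from this diff):

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")  # placeholder checkpoint
pipe.load_lora_weights("path/to/lora")  # requires the PEFT backend: `pip install peft`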
 
2937
4250
  ADDED_KV_ATTENTION_PROCESSORS = (
2938
4251
  AttnAddedKVProcessor,
2939
4252
  SlicedAttnAddedKVProcessor,
2940
4253
  AttnAddedKVProcessor2_0,
2941
4254
  XFormersAttnAddedKVProcessor,
2942
- LoRAAttnAddedKVProcessor,
2943
4255
  )
2944
4256
 
2945
4257
  CROSS_ATTENTION_PROCESSORS = (
@@ -2947,9 +4259,6 @@ CROSS_ATTENTION_PROCESSORS = (
2947
4259
  AttnProcessor2_0,
2948
4260
  XFormersAttnProcessor,
2949
4261
  SlicedAttnProcessor,
2950
- LoRAAttnProcessor,
2951
- LoRAAttnProcessor2_0,
2952
- LoRAXFormersAttnProcessor,
2953
4262
  IPAdapterAttnProcessor,
2954
4263
  IPAdapterAttnProcessor2_0,
2955
4264
  )
@@ -2967,9 +4276,8 @@ AttentionProcessor = Union[
2967
4276
  CustomDiffusionAttnProcessor,
2968
4277
  CustomDiffusionXFormersAttnProcessor,
2969
4278
  CustomDiffusionAttnProcessor2_0,
2970
- # deprecated
2971
- LoRAAttnProcessor,
2972
- LoRAAttnProcessor2_0,
2973
- LoRAXFormersAttnProcessor,
2974
- LoRAAttnAddedKVProcessor,
4279
+ PAGCFGIdentitySelfAttnProcessor2_0,
4280
+ PAGIdentitySelfAttnProcessor2_0,
4281
+ PAGCFGHunyuanAttnProcessor2_0,
4282
+ PAGHunyuanAttnProcessor2_0,
2975
4283
  ]