diffusers 0.29.2__py3-none-any.whl → 0.30.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_pipeline.py +2222 -0
  10. diffusers/loaders/peft.py +213 -5
  11. diffusers/loaders/single_file.py +1 -12
  12. diffusers/loaders/single_file_model.py +31 -10
  13. diffusers/loaders/single_file_utils.py +262 -2
  14. diffusers/loaders/textual_inversion.py +1 -6
  15. diffusers/loaders/unet.py +23 -208
  16. diffusers/models/__init__.py +20 -0
  17. diffusers/models/activations.py +22 -0
  18. diffusers/models/attention.py +386 -7
  19. diffusers/models/attention_processor.py +1795 -629
  20. diffusers/models/autoencoders/__init__.py +2 -0
  21. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  22. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1035 -0
  23. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  24. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  25. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  26. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  27. diffusers/models/autoencoders/vq_model.py +4 -4
  28. diffusers/models/controlnet.py +2 -3
  29. diffusers/models/controlnet_hunyuan.py +401 -0
  30. diffusers/models/controlnet_sd3.py +11 -11
  31. diffusers/models/controlnet_sparsectrl.py +789 -0
  32. diffusers/models/controlnet_xs.py +40 -10
  33. diffusers/models/downsampling.py +68 -0
  34. diffusers/models/embeddings.py +319 -36
  35. diffusers/models/model_loading_utils.py +1 -3
  36. diffusers/models/modeling_flax_utils.py +1 -6
  37. diffusers/models/modeling_utils.py +4 -16
  38. diffusers/models/normalization.py +203 -12
  39. diffusers/models/transformers/__init__.py +6 -0
  40. diffusers/models/transformers/auraflow_transformer_2d.py +527 -0
  41. diffusers/models/transformers/cogvideox_transformer_3d.py +345 -0
  42. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  43. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  44. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  45. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  46. diffusers/models/transformers/prior_transformer.py +1 -1
  47. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  48. diffusers/models/transformers/transformer_flux.py +455 -0
  49. diffusers/models/transformers/transformer_sd3.py +18 -4
  50. diffusers/models/unets/unet_1d_blocks.py +1 -1
  51. diffusers/models/unets/unet_2d_condition.py +8 -1
  52. diffusers/models/unets/unet_3d_blocks.py +51 -920
  53. diffusers/models/unets/unet_3d_condition.py +4 -1
  54. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  55. diffusers/models/unets/unet_kandinsky3.py +1 -1
  56. diffusers/models/unets/unet_motion_model.py +1330 -84
  57. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  58. diffusers/models/unets/unet_stable_cascade.py +1 -3
  59. diffusers/models/unets/uvit_2d.py +1 -1
  60. diffusers/models/upsampling.py +64 -0
  61. diffusers/models/vq_model.py +8 -4
  62. diffusers/optimization.py +1 -1
  63. diffusers/pipelines/__init__.py +100 -3
  64. diffusers/pipelines/animatediff/__init__.py +4 -0
  65. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  66. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  70. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  71. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  72. diffusers/pipelines/aura_flow/__init__.py +48 -0
  73. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  74. diffusers/pipelines/auto_pipeline.py +97 -19
  75. diffusers/pipelines/cogvideo/__init__.py +48 -0
  76. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +687 -0
  77. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  78. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  79. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  80. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  81. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  82. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  83. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  84. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  85. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  86. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  87. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  88. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  90. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  91. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  96. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  97. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  98. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  103. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  104. diffusers/pipelines/flux/__init__.py +47 -0
  105. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  106. diffusers/pipelines/flux/pipeline_output.py +21 -0
  107. diffusers/pipelines/free_init_utils.py +2 -0
  108. diffusers/pipelines/free_noise_utils.py +236 -0
  109. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  110. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  111. diffusers/pipelines/kolors/__init__.py +54 -0
  112. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  113. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  114. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  115. diffusers/pipelines/kolors/text_encoder.py +889 -0
  116. diffusers/pipelines/kolors/tokenizer.py +334 -0
  117. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  118. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  119. diffusers/pipelines/latte/__init__.py +48 -0
  120. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  121. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  122. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  123. diffusers/pipelines/lumina/__init__.py +48 -0
  124. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  125. diffusers/pipelines/pag/__init__.py +67 -0
  126. diffusers/pipelines/pag/pag_utils.py +237 -0
  127. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  128. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  129. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  130. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  131. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  132. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  133. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  138. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  139. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  140. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  141. diffusers/pipelines/pipeline_utils.py +2 -14
  142. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  143. diffusers/pipelines/stable_audio/__init__.py +50 -0
  144. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  145. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  146. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  147. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  156. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  157. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  158. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  160. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  161. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  162. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  163. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  164. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  165. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  166. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  167. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  168. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  175. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  179. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  180. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  181. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  182. diffusers/schedulers/__init__.py +8 -0
  183. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  184. diffusers/schedulers/scheduling_ddim.py +1 -1
  185. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  186. diffusers/schedulers/scheduling_ddpm.py +1 -1
  187. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  188. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  189. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  190. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  191. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  192. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  193. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  194. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  195. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  196. diffusers/schedulers/scheduling_ipndm.py +1 -1
  197. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  198. diffusers/schedulers/scheduling_utils.py +1 -3
  199. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  200. diffusers/training_utils.py +99 -14
  201. diffusers/utils/__init__.py +2 -2
  202. diffusers/utils/dummy_pt_objects.py +210 -0
  203. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  204. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  205. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  206. diffusers/utils/dynamic_modules_utils.py +1 -11
  207. diffusers/utils/export_utils.py +1 -4
  208. diffusers/utils/hub_utils.py +45 -42
  209. diffusers/utils/import_utils.py +19 -16
  210. diffusers/utils/loading_utils.py +76 -3
  211. diffusers/utils/testing_utils.py +11 -8
  212. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/METADATA +73 -83
  213. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/RECORD +217 -164
  214. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/WHEEL +1 -1
  215. diffusers/loaders/autoencoder.py +0 -146
  216. diffusers/loaders/controlnet.py +0 -136
  217. diffusers/loaders/lora.py +0 -1728
  218. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/LICENSE +0 -0
  219. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/entry_points.txt +0 -0
  220. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/top_level.txt +0 -0
@@ -13,8 +13,7 @@
  # limitations under the License.
  import inspect
  import math
- from importlib import import_module
- from typing import Callable, List, Optional, Union
+ from typing import Callable, List, Optional, Tuple, Union

  import torch
  import torch.nn.functional as F
@@ -23,8 +22,7 @@ from torch import nn
  from ..image_processor import IPAdapterMaskProcessor
  from ..utils import deprecate, logging
  from ..utils.import_utils import is_torch_npu_available, is_xformers_available
- from ..utils.torch_utils import maybe_allow_in_graph
- from .lora import LoRALinearLayer
+ from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph


  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -51,6 +49,10 @@ class Attention(nn.Module):
  The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
  heads (`int`, *optional*, defaults to 8):
  The number of heads to use for multi-head attention.
+ kv_heads (`int`, *optional*, defaults to `None`):
+ The number of key and value heads to use for multi-head attention. Defaults to `heads`. If
+ `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi
+ Query Attention (MQA) otherwise GQA is used.
  dim_head (`int`, *optional*, defaults to 64):
  The number of channels in each head.
  dropout (`float`, *optional*, defaults to 0.0):
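Note: read together with the constructor changes in the hunks below, `kv_heads` is the switch that enables grouped-query attention: the key/value projections are sized `dim_head * kv_heads` (`inner_kv_dim`) instead of `dim_head * heads`. A minimal sketch of the head arithmetic, with illustrative numbers that are not taken from any shipped model config:

# Hypothetical numbers, only to illustrate the MHA / GQA / MQA cases described above.
heads, dim_head = 8, 64
for kv_heads in (8, 2, 1):  # 8 -> MHA, 2 -> GQA, 1 -> MQA
    inner_dim = dim_head * heads        # query projection width: 512
    inner_kv_dim = dim_head * kv_heads  # key/value projection width: 512, 128, 64
    print(kv_heads, inner_dim, inner_kv_dim)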
@@ -96,6 +98,7 @@ class Attention(nn.Module):
  query_dim: int,
  cross_attention_dim: Optional[int] = None,
  heads: int = 8,
+ kv_heads: Optional[int] = None,
  dim_head: int = 64,
  dropout: float = 0.0,
  bias: bool = False,
@@ -105,6 +108,7 @@ class Attention(nn.Module):
  cross_attention_norm_num_groups: int = 32,
  qk_norm: Optional[str] = None,
  added_kv_proj_dim: Optional[int] = None,
+ added_proj_bias: Optional[bool] = True,
  norm_num_groups: Optional[int] = None,
  spatial_norm_dim: Optional[int] = None,
  out_bias: bool = True,
@@ -117,9 +121,15 @@ class Attention(nn.Module):
  processor: Optional["AttnProcessor"] = None,
  out_dim: int = None,
  context_pre_only=None,
+ pre_only=False,
  ):
  super().__init__()
+
+ # To prevent circular import.
+ from .normalization import FP32LayerNorm, RMSNorm
+
  self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
  self.query_dim = query_dim
  self.use_bias = bias
  self.is_cross_attention = cross_attention_dim is not None
@@ -132,6 +142,7 @@ class Attention(nn.Module):
  self.fused_projections = False
  self.out_dim = out_dim if out_dim is not None else query_dim
  self.context_pre_only = context_pre_only
+ self.pre_only = pre_only

  # we make use of this private variable to know whether this class is loaded
  # with an deprecated state dict so that we can convert it on the fly
@@ -170,6 +181,16 @@ class Attention(nn.Module):
  elif qk_norm == "layer_norm":
  self.norm_q = nn.LayerNorm(dim_head, eps=eps)
  self.norm_k = nn.LayerNorm(dim_head, eps=eps)
+ elif qk_norm == "fp32_layer_norm":
+ self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ elif qk_norm == "layer_norm_across_heads":
+ # Lumina applys qk norm across all heads
+ self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
+ self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
+ elif qk_norm == "rms_norm":
+ self.norm_q = RMSNorm(dim_head, eps=eps)
+ self.norm_k = RMSNorm(dim_head, eps=eps)
  else:
  raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")

@@ -200,25 +221,38 @@ class Attention(nn.Module):

  if not self.only_cross_attention:
  # only relevant for the `AddedKVProcessor` classes
- self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
- self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
+ self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+ self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
  else:
  self.to_k = None
  self.to_v = None

+ self.added_proj_bias = added_proj_bias
  if self.added_kv_proj_dim is not None:
- self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
- self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
  if self.context_pre_only is not None:
- self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
+ self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)

- self.to_out = nn.ModuleList([])
- self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
- self.to_out.append(nn.Dropout(dropout))
+ if not self.pre_only:
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(nn.Dropout(dropout))

  if self.context_pre_only is not None and not self.context_pre_only:
  self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)

+ if qk_norm is not None and added_kv_proj_dim is not None:
+ if qk_norm == "fp32_layer_norm":
+ self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ elif qk_norm == "rms_norm":
+ self.norm_added_q = RMSNorm(dim_head, eps=eps)
+ self.norm_added_k = RMSNorm(dim_head, eps=eps)
+ else:
+ self.norm_added_q = None
+ self.norm_added_k = None
+
  # set attention processor
  # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
  # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
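Note: after this change the accepted `qk_norm` values are `None`, "layer_norm", "fp32_layer_norm", "layer_norm_across_heads" and "rms_norm", and when `added_kv_proj_dim` is set the added projections get their own `norm_added_q`/`norm_added_k` for the "fp32_layer_norm" and "rms_norm" variants. A hedged sketch of opting into the new options via the constructor shown above (the concrete values are illustrative, not taken from any shipped pipeline):

from diffusers.models.attention_processor import Attention

# Illustrative configuration only; real models define their own sizes.
attn = Attention(
    query_dim=1024,
    heads=8,
    kv_heads=2,          # grouped-query K/V projections (inner_kv_dim = dim_head * kv_heads)
    dim_head=64,
    qk_norm="rms_norm",  # per-head RMSNorm applied to Q and K
)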
@@ -259,10 +293,6 @@ class Attention(nn.Module):
  The attention operation to use. Defaults to `None` which uses the default attention operation from
  `xformers`.
  """
- is_lora = hasattr(self, "processor") and isinstance(
- self.processor,
- LORA_ATTENTION_PROCESSORS,
- )
  is_custom_diffusion = hasattr(self, "processor") and isinstance(
  self.processor,
  (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
@@ -274,14 +304,13 @@ class Attention(nn.Module):
  AttnAddedKVProcessor2_0,
  SlicedAttnAddedKVProcessor,
  XFormersAttnAddedKVProcessor,
- LoRAAttnAddedKVProcessor,
  ),
  )

  if use_memory_efficient_attention_xformers:
- if is_added_kv_processor and (is_lora or is_custom_diffusion):
+ if is_added_kv_processor and is_custom_diffusion:
  raise NotImplementedError(
- f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
+ f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}"
  )
  if not is_xformers_available():
  raise ModuleNotFoundError(
@@ -307,18 +336,7 @@ class Attention(nn.Module):
  except Exception as e:
  raise e

- if is_lora:
- # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
- # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
- processor = LoRAXFormersAttnProcessor(
- hidden_size=self.processor.hidden_size,
- cross_attention_dim=self.processor.cross_attention_dim,
- rank=self.processor.rank,
- attention_op=attention_op,
- )
- processor.load_state_dict(self.processor.state_dict())
- processor.to(self.processor.to_q_lora.up.weight.device)
- elif is_custom_diffusion:
+ if is_custom_diffusion:
  processor = CustomDiffusionXFormersAttnProcessor(
  train_kv=self.processor.train_kv,
  train_q_out=self.processor.train_q_out,
@@ -341,18 +359,7 @@ class Attention(nn.Module):
  else:
  processor = XFormersAttnProcessor(attention_op=attention_op)
  else:
- if is_lora:
- attn_processor_class = (
- LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
- )
- processor = attn_processor_class(
- hidden_size=self.processor.hidden_size,
- cross_attention_dim=self.processor.cross_attention_dim,
- rank=self.processor.rank,
- )
- processor.load_state_dict(self.processor.state_dict())
- processor.to(self.processor.to_q_lora.up.weight.device)
- elif is_custom_diffusion:
+ if is_custom_diffusion:
  attn_processor_class = (
  CustomDiffusionAttnProcessor2_0
  if hasattr(F, "scaled_dot_product_attention")
@@ -442,82 +449,6 @@ class Attention(nn.Module):
  if not return_deprecated_lora:
  return self.processor

- # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
- # serialization format for LoRA Attention Processors. It should be deleted once the integration
- # with PEFT is completed.
- is_lora_activated = {
- name: module.lora_layer is not None
- for name, module in self.named_modules()
- if hasattr(module, "lora_layer")
- }
-
- # 1. if no layer has a LoRA activated we can return the processor as usual
- if not any(is_lora_activated.values()):
- return self.processor
-
- # If doesn't apply LoRA do `add_k_proj` or `add_v_proj`
- is_lora_activated.pop("add_k_proj", None)
- is_lora_activated.pop("add_v_proj", None)
- # 2. else it is not possible that only some layers have LoRA activated
- if not all(is_lora_activated.values()):
- raise ValueError(
- f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
- )
-
- # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
- non_lora_processor_cls_name = self.processor.__class__.__name__
- lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
-
- hidden_size = self.inner_dim
-
- # now create a LoRA attention processor from the LoRA layers
- if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
- kwargs = {
- "cross_attention_dim": self.cross_attention_dim,
- "rank": self.to_q.lora_layer.rank,
- "network_alpha": self.to_q.lora_layer.network_alpha,
- "q_rank": self.to_q.lora_layer.rank,
- "q_hidden_size": self.to_q.lora_layer.out_features,
- "k_rank": self.to_k.lora_layer.rank,
- "k_hidden_size": self.to_k.lora_layer.out_features,
- "v_rank": self.to_v.lora_layer.rank,
- "v_hidden_size": self.to_v.lora_layer.out_features,
- "out_rank": self.to_out[0].lora_layer.rank,
- "out_hidden_size": self.to_out[0].lora_layer.out_features,
- }
-
- if hasattr(self.processor, "attention_op"):
- kwargs["attention_op"] = self.processor.attention_op
-
- lora_processor = lora_processor_cls(hidden_size, **kwargs)
- lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
- lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
- lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
- lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
- elif lora_processor_cls == LoRAAttnAddedKVProcessor:
- lora_processor = lora_processor_cls(
- hidden_size,
- cross_attention_dim=self.add_k_proj.weight.shape[0],
- rank=self.to_q.lora_layer.rank,
- network_alpha=self.to_q.lora_layer.network_alpha,
- )
- lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
- lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
- lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
- lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
-
- # only save if used
- if self.add_k_proj.lora_layer is not None:
- lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
- lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
- else:
- lora_processor.add_k_proj_lora = None
- lora_processor.add_v_proj_lora = None
- else:
- raise ValueError(f"{lora_processor_cls} does not exist.")
-
- return lora_processor
-
  def forward(
  self,
  hidden_states: torch.Tensor,
@@ -609,7 +540,7 @@ class Attention(nn.Module):
  return tensor

  def get_attention_scores(
- self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
+ self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
  ) -> torch.Tensor:
  r"""
  Compute the attention scores.
@@ -760,6 +691,24 @@ class Attention(nn.Module):
  concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
  self.to_kv.bias.copy_(concatenated_bias)

+ # handle added projections for SD3 and others.
+ if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
+ concatenated_weights = torch.cat(
+ [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
+ )
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_added_qkv = nn.Linear(
+ in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
+ )
+ self.to_added_qkv.weight.copy_(concatenated_weights)
+ if self.added_proj_bias:
+ concatenated_bias = torch.cat(
+ [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
+ )
+ self.to_added_qkv.bias.copy_(concatenated_bias)
+
  self.fused_projections = fuse

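Note: the new branch above extends `fuse_projections` to the added (context) projections used by SD3-style joint attention, concatenating `add_q_proj`, `add_k_proj` and `add_v_proj` into a single `to_added_qkv` linear that `FusedJointAttnProcessor2_0` later splits back into three chunks. A stand-alone sketch of the same weight bookkeeping, independent of diffusers:

import torch
import torch.nn as nn

# Three (out, in) weight matrices stacked along the output dimension become one
# linear layer whose output splits back into q, k, v.
d = 16
add_q, add_k, add_v = (nn.Linear(d, d) for _ in range(3))
fused = nn.Linear(d, 3 * d)
with torch.no_grad():
    fused.weight.copy_(torch.cat([add_q.weight, add_k.weight, add_v.weight]))
    fused.bias.copy_(torch.cat([add_q.bias, add_k.bias, add_v.bias]))

x = torch.randn(2, d)
q, k, v = torch.split(fused(x), d, dim=-1)
assert torch.allclose(q, add_q(x), atol=1e-6)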
 
@@ -1132,9 +1081,7 @@ class JointAttnProcessor2_0:
  key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
  value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

- hidden_states = hidden_states = F.scaled_dot_product_attention(
- query, key, value, dropout_p=0.0, is_causal=False
- )
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
  hidden_states = hidden_states.to(query.dtype)

@@ -1159,21 +1106,20 @@ class JointAttnProcessor2_0:
  return hidden_states, encoder_hidden_states


- class FusedJointAttnProcessor2_0:
+ class PAGJointAttnProcessor2_0:
  """Attention processor used typically in processing the SD3-like self-attention projections."""

  def __init__(self):
  if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+ raise ImportError(
+ "PAGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )

  def __call__(
  self,
  attn: Attention,
  hidden_states: torch.FloatTensor,
  encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- *args,
- **kwargs,
  ) -> torch.FloatTensor:
  residual = hidden_states

@@ -1186,257 +1132,1267 @@ class FusedJointAttnProcessor2_0:
  batch_size, channel, height, width = encoder_hidden_states.shape
  encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

- batch_size = encoder_hidden_states.shape[0]
+ # store the length of image patch sequences to create a mask that prevents interaction between patches
+ # similar to making the self-attention map an identity matrix
+ identity_block_size = hidden_states.shape[1]
+
+ # chunk
+ hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
+ encoder_hidden_states_org, encoder_hidden_states_ptb = encoder_hidden_states.chunk(2)
+
+ ################## original path ##################
+ batch_size = encoder_hidden_states_org.shape[0]

  # `sample` projections.
- qkv = attn.to_qkv(hidden_states)
- split_size = qkv.shape[-1] // 3
- query, key, value = torch.split(qkv, split_size, dim=-1)
+ query_org = attn.to_q(hidden_states_org)
+ key_org = attn.to_k(hidden_states_org)
+ value_org = attn.to_v(hidden_states_org)

  # `context` projections.
- encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
- split_size = encoder_qkv.shape[-1] // 3
- (
- encoder_hidden_states_query_proj,
- encoder_hidden_states_key_proj,
- encoder_hidden_states_value_proj,
- ) = torch.split(encoder_qkv, split_size, dim=-1)
+ encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
+ encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
+ encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)

  # attention
- query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
- key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
- value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+ query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
+ key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
+ value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)

- inner_dim = key.shape[-1]
+ inner_dim = key_org.shape[-1]
  head_dim = inner_dim // attn.heads
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

- hidden_states = hidden_states = F.scaled_dot_product_attention(
- query, key, value, dropout_p=0.0, is_causal=False
+ hidden_states_org = F.scaled_dot_product_attention(
+ query_org, key_org, value_org, dropout_p=0.0, is_causal=False
  )
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states_org = hidden_states_org.to(query_org.dtype)

  # Split the attention outputs.
- hidden_states, encoder_hidden_states = (
- hidden_states[:, : residual.shape[1]],
- hidden_states[:, residual.shape[1] :],
+ hidden_states_org, encoder_hidden_states_org = (
+ hidden_states_org[:, : residual.shape[1]],
+ hidden_states_org[:, residual.shape[1] :],
  )

  # linear proj
- hidden_states = attn.to_out[0](hidden_states)
+ hidden_states_org = attn.to_out[0](hidden_states_org)
  # dropout
- hidden_states = attn.to_out[1](hidden_states)
+ hidden_states_org = attn.to_out[1](hidden_states_org)
  if not attn.context_pre_only:
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+ encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)

  if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
  if context_input_ndim == 4:
- encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- return hidden_states, encoder_hidden_states
+ encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
+ batch_size, channel, height, width
+ )

+ ################## perturbed path ##################

- class XFormersAttnAddedKVProcessor:
- r"""
- Processor for implementing memory efficient attention using xFormers.
+ batch_size = encoder_hidden_states_ptb.shape[0]

- Args:
- attention_op (`Callable`, *optional*, defaults to `None`):
- The base
- [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
- use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
- operator.
- """
+ # `sample` projections.
+ query_ptb = attn.to_q(hidden_states_ptb)
+ key_ptb = attn.to_k(hidden_states_ptb)
+ value_ptb = attn.to_v(hidden_states_ptb)

- def __init__(self, attention_op: Optional[Callable] = None):
- self.attention_op = attention_op
+ # `context` projections.
+ encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
+ encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
+ encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)

- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- residual = hidden_states
- hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
- batch_size, sequence_length, _ = hidden_states.shape
+ # attention
+ query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
+ key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
+ value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)

- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ inner_dim = key_ptb.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+ # create a full mask with all entries set to 0
+ seq_len = query_ptb.size(2)
+ full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)

- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+ # set the attention value between image patches to -inf
+ full_mask[:identity_block_size, :identity_block_size] = float("-inf")

- query = attn.to_q(hidden_states)
- query = attn.head_to_batch_dim(query)
+ # set the diagonal of the attention value between image patches to 0
+ full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)

- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
- encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
- encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+ # expand the mask to match the attention weights shape
+ full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions

- if not attn.only_cross_attention:
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
- key = attn.head_to_batch_dim(key)
- value = attn.head_to_batch_dim(value)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
- else:
- key = encoder_hidden_states_key_proj
- value = encoder_hidden_states_value_proj
+ hidden_states_ptb = F.scaled_dot_product_attention(
+ query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)

- hidden_states = xformers.ops.memory_efficient_attention(
- query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+ # split the attention outputs.
+ hidden_states_ptb, encoder_hidden_states_ptb = (
+ hidden_states_ptb[:, : residual.shape[1]],
+ hidden_states_ptb[:, residual.shape[1] :],
  )
- hidden_states = hidden_states.to(query.dtype)
- hidden_states = attn.batch_to_head_dim(hidden_states)

  # linear proj
- hidden_states = attn.to_out[0](hidden_states)
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
  # dropout
- hidden_states = attn.to_out[1](hidden_states)
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+ if not attn.context_pre_only:
+ encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)

- hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
- hidden_states = hidden_states + residual
+ if input_ndim == 4:
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ if context_input_ndim == 4:
+ encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
+ batch_size, channel, height, width
+ )

- return hidden_states
+ ################ concat ###############
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+ encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])

+ return hidden_states, encoder_hidden_states

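Note: the perturbed branch above is the core of perturbed-attention guidance (PAG) for joint attention: within the image-token block of the sequence the off-diagonal attention logits are set to -inf and the diagonal to 0, so each image patch can attend only to itself and to the tokens appended after `identity_block_size`. A small self-contained sketch of that mask construction, using made-up sizes:

import torch

# Toy sizes, not taken from any real model: 4 image patches followed by 3 text tokens.
identity_block_size, seq_len = 4, 7
full_mask = torch.zeros(seq_len, seq_len)
full_mask[:identity_block_size, :identity_block_size] = float("-inf")
full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)

# Row i (i < 4) keeps only position i and the text positions 4..6 reachable.
print(full_mask)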
- class XFormersAttnProcessor:
- r"""
- Processor for implementing memory efficient attention using xFormers.

- Args:
- attention_op (`Callable`, *optional*, defaults to `None`):
- The base
- [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
- use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
- operator.
- """
+ class PAGCFGJointAttnProcessor2_0:
+ """Attention processor used typically in processing the SD3-like self-attention projections."""

- def __init__(self, attention_op: Optional[Callable] = None):
- self.attention_op = attention_op
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "PAGCFGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )

  def __call__(
  self,
  attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- temb: Optional[torch.Tensor] = None,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
  *args,
  **kwargs,
- ) -> torch.Tensor:
- if len(args) > 0 or kwargs.get("scale", None) is not None:
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
- deprecate("scale", "1.0.0", deprecation_message)
-
+ ) -> torch.FloatTensor:
  residual = hidden_states

- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
  input_ndim = hidden_states.ndim
-
  if input_ndim == 4:
  batch_size, channel, height, width = hidden_states.shape
  hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+ context_input_ndim = encoder_hidden_states.ndim
+ if context_input_ndim == 4:
+ batch_size, channel, height, width = encoder_hidden_states.shape
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

- batch_size, key_tokens, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
+ identity_block_size = hidden_states.shape[
+ 1
+ ] # patch embeddings width * height (correspond to self-attention map width or height)

- attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
- if attention_mask is not None:
- # expand our mask's singleton query_tokens dimension:
- # [batch*heads, 1, key_tokens] ->
- # [batch*heads, query_tokens, key_tokens]
- # so that it can be added as a bias onto the attention scores that xformers computes:
- # [batch*heads, query_tokens, key_tokens]
- # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
- _, query_tokens, _ = hidden_states.shape
- attention_mask = attention_mask.expand(-1, query_tokens, -1)
+ # chunk
+ hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
+ hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])

- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+ (
+ encoder_hidden_states_uncond,
+ encoder_hidden_states_org,
+ encoder_hidden_states_ptb,
+ ) = encoder_hidden_states.chunk(3)
+ encoder_hidden_states_org = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_org])

- query = attn.to_q(hidden_states)
+ ################## original path ##################
+ batch_size = encoder_hidden_states_org.shape[0]

- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+ # `sample` projections.
+ query_org = attn.to_q(hidden_states_org)
+ key_org = attn.to_k(hidden_states_org)
+ value_org = attn.to_v(hidden_states_org)

- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
+ # `context` projections.
+ encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
+ encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
+ encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)

- query = attn.head_to_batch_dim(query).contiguous()
- key = attn.head_to_batch_dim(key).contiguous()
- value = attn.head_to_batch_dim(value).contiguous()
+ # attention
+ query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
+ key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
+ value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)

- hidden_states = xformers.ops.memory_efficient_attention(
- query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+ inner_dim = key_org.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ hidden_states_org = F.scaled_dot_product_attention(
+ query_org, key_org, value_org, dropout_p=0.0, is_causal=False
+ )
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states_org = hidden_states_org.to(query_org.dtype)
+
+ # Split the attention outputs.
+ hidden_states_org, encoder_hidden_states_org = (
+ hidden_states_org[:, : residual.shape[1]],
+ hidden_states_org[:, residual.shape[1] :],
  )
- hidden_states = hidden_states.to(query.dtype)
- hidden_states = attn.batch_to_head_dim(hidden_states)

  # linear proj
- hidden_states = attn.to_out[0](hidden_states)
+ hidden_states_org = attn.to_out[0](hidden_states_org)
  # dropout
- hidden_states = attn.to_out[1](hidden_states)
+ hidden_states_org = attn.to_out[1](hidden_states_org)
+ if not attn.context_pre_only:
+ encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)

  if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ if context_input_ndim == 4:
+ encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
+ batch_size, channel, height, width
+ )

- if attn.residual_connection:
- hidden_states = hidden_states + residual
+ ################## perturbed path ##################

- hidden_states = hidden_states / attn.rescale_output_factor
+ batch_size = encoder_hidden_states_ptb.shape[0]

- return hidden_states
+ # `sample` projections.
+ query_ptb = attn.to_q(hidden_states_ptb)
+ key_ptb = attn.to_k(hidden_states_ptb)
+ value_ptb = attn.to_v(hidden_states_ptb)

+ # `context` projections.
+ encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
+ encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
+ encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)

- class AttnProcessorNPU:
1372
+ # attention
1373
+ query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
1374
+ key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
1375
+ value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
1409
1376
 
1410
- r"""
1411
- Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
1412
- fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
1413
- not significant.
1377
+ inner_dim = key_ptb.shape[-1]
1378
+ head_dim = inner_dim // attn.heads
1379
+ query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1380
+ key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1381
+ value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1414
1382
 
1415
- """
1383
+ # create a full mask with all entries set to 0
1384
+ seq_len = query_ptb.size(2)
1385
+ full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
1416
1386
 
1417
- def __init__(self):
1418
- if not is_torch_npu_available():
1419
- raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")
1387
+ # set the attention value between image patches to -inf
1388
+ full_mask[:identity_block_size, :identity_block_size] = float("-inf")
1420
1389
 
1421
- def __call__(
1422
- self,
1423
- attn: Attention,
1424
- hidden_states: torch.Tensor,
1425
- encoder_hidden_states: Optional[torch.Tensor] = None,
1426
- attention_mask: Optional[torch.Tensor] = None,
1427
- temb: Optional[torch.Tensor] = None,
1428
- *args,
1429
- **kwargs,
1430
- ) -> torch.Tensor:
1431
- if len(args) > 0 or kwargs.get("scale", None) is not None:
1432
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1433
- deprecate("scale", "1.0.0", deprecation_message)
1390
+ # set the diagonal of the attention value between image patches to 0
1391
+ full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
1434
1392
 
1435
- residual = hidden_states
1436
- if attn.spatial_norm is not None:
1437
- hidden_states = attn.spatial_norm(hidden_states, temb)
1393
+ # expand the mask to match the attention weights shape
1394
+ full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions
1438
1395
 
1439
- input_ndim = hidden_states.ndim
1396
+ hidden_states_ptb = F.scaled_dot_product_attention(
1397
+ query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
1398
+ )
1399
+ hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1400
+ hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
1401
+
1402
+ # split the attention outputs.
1403
+ hidden_states_ptb, encoder_hidden_states_ptb = (
1404
+ hidden_states_ptb[:, : residual.shape[1]],
1405
+ hidden_states_ptb[:, residual.shape[1] :],
1406
+ )
1407
+
1408
+ # linear proj
1409
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
1410
+ # dropout
1411
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
1412
+ if not attn.context_pre_only:
1413
+ encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
1414
+
1415
+ if input_ndim == 4:
1416
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
1417
+ if context_input_ndim == 4:
1418
+ encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
1419
+ batch_size, channel, height, width
1420
+ )
1421
+
1422
+ ################ concat ###############
1423
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
1424
+ encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
1425
+
1426
+ return hidden_states, encoder_hidden_states
1427
+
1428
+
1429
+ class FusedJointAttnProcessor2_0:
1430
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
1431
+
1432
+ def __init__(self):
1433
+ if not hasattr(F, "scaled_dot_product_attention"):
1434
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1435
+
1436
+ def __call__(
1437
+ self,
1438
+ attn: Attention,
1439
+ hidden_states: torch.FloatTensor,
1440
+ encoder_hidden_states: torch.FloatTensor = None,
1441
+ attention_mask: Optional[torch.FloatTensor] = None,
1442
+ *args,
1443
+ **kwargs,
1444
+ ) -> torch.FloatTensor:
1445
+ residual = hidden_states
1446
+
1447
+ input_ndim = hidden_states.ndim
1448
+ if input_ndim == 4:
1449
+ batch_size, channel, height, width = hidden_states.shape
1450
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1451
+ context_input_ndim = encoder_hidden_states.ndim
1452
+ if context_input_ndim == 4:
1453
+ batch_size, channel, height, width = encoder_hidden_states.shape
1454
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1455
+
1456
+ batch_size = encoder_hidden_states.shape[0]
1457
+
1458
+ # `sample` projections.
1459
+ qkv = attn.to_qkv(hidden_states)
1460
+ split_size = qkv.shape[-1] // 3
1461
+ query, key, value = torch.split(qkv, split_size, dim=-1)
1462
+
1463
+ # `context` projections.
1464
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
1465
+ split_size = encoder_qkv.shape[-1] // 3
1466
+ (
1467
+ encoder_hidden_states_query_proj,
1468
+ encoder_hidden_states_key_proj,
1469
+ encoder_hidden_states_value_proj,
1470
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
1471
+
1472
+ # attention
1473
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
1474
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
1475
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
1476
+
1477
+ inner_dim = key.shape[-1]
1478
+ head_dim = inner_dim // attn.heads
1479
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1480
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1481
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1482
+
1483
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
1484
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1485
+ hidden_states = hidden_states.to(query.dtype)
1486
+
1487
+ # Split the attention outputs.
1488
+ hidden_states, encoder_hidden_states = (
1489
+ hidden_states[:, : residual.shape[1]],
1490
+ hidden_states[:, residual.shape[1] :],
1491
+ )
1492
+
1493
+ # linear proj
1494
+ hidden_states = attn.to_out[0](hidden_states)
1495
+ # dropout
1496
+ hidden_states = attn.to_out[1](hidden_states)
1497
+ if not attn.context_pre_only:
1498
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
1499
+
1500
+ if input_ndim == 4:
1501
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1502
+ if context_input_ndim == 4:
1503
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1504
+
1505
+ return hidden_states, encoder_hidden_states
1506
+
1507
+
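FusedJointAttnProcessor2_0 assumes `attn.to_qkv` / `attn.to_added_qkv` already exist (they are created when the separate projections are fused elsewhere). A minimal plain-PyTorch sketch, with hypothetical dimensions, of why splitting one fused projection is equivalent to three separate ones:

import torch
import torch.nn as nn

dim = 64                      # hypothetical inner dimension
x = torch.randn(2, 10, dim)   # (batch, tokens, dim)

# Three separate projections, as used by the unfused processors.
to_q, to_k, to_v = nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)

# One fused projection whose weight and bias are the three stacked together.
to_qkv = nn.Linear(dim, 3 * dim)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))
    to_qkv.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias], dim=0))

# Same split as in the processor: one matmul, then three equal chunks.
qkv = to_qkv(x)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)

assert torch.allclose(query, to_q(x), atol=1e-6)
assert torch.allclose(key, to_k(x), atol=1e-6)
assert torch.allclose(value, to_v(x), atol=1e-6)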
1508
+ class AuraFlowAttnProcessor2_0:
1509
+ """Attention processor used typically in processing Aura Flow."""
1510
+
1511
+ def __init__(self):
1512
+ if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
1513
+ raise ImportError(
1514
+ "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
1515
+ )
1516
+
1517
+ def __call__(
1518
+ self,
1519
+ attn: Attention,
1520
+ hidden_states: torch.FloatTensor,
1521
+ encoder_hidden_states: torch.FloatTensor = None,
1522
+ *args,
1523
+ **kwargs,
1524
+ ) -> torch.FloatTensor:
1525
+ batch_size = hidden_states.shape[0]
1526
+
1527
+ # `sample` projections.
1528
+ query = attn.to_q(hidden_states)
1529
+ key = attn.to_k(hidden_states)
1530
+ value = attn.to_v(hidden_states)
1531
+
1532
+ # `context` projections.
1533
+ if encoder_hidden_states is not None:
1534
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
1535
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1536
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1537
+
1538
+ # Reshape.
1539
+ inner_dim = key.shape[-1]
1540
+ head_dim = inner_dim // attn.heads
1541
+ query = query.view(batch_size, -1, attn.heads, head_dim)
1542
+ key = key.view(batch_size, -1, attn.heads, head_dim)
1543
+ value = value.view(batch_size, -1, attn.heads, head_dim)
1544
+
1545
+ # Apply QK norm.
1546
+ if attn.norm_q is not None:
1547
+ query = attn.norm_q(query)
1548
+ if attn.norm_k is not None:
1549
+ key = attn.norm_k(key)
1550
+
1551
+ # Concatenate the projections.
1552
+ if encoder_hidden_states is not None:
1553
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
1554
+ batch_size, -1, attn.heads, head_dim
1555
+ )
1556
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
1557
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
1558
+ batch_size, -1, attn.heads, head_dim
1559
+ )
1560
+
1561
+ if attn.norm_added_q is not None:
1562
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
1563
+ if attn.norm_added_k is not None:
1564
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
1565
+
1566
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
1567
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
1568
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
1569
+
1570
+ query = query.transpose(1, 2)
1571
+ key = key.transpose(1, 2)
1572
+ value = value.transpose(1, 2)
1573
+
1574
+ # Attention.
1575
+ hidden_states = F.scaled_dot_product_attention(
1576
+ query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
1577
+ )
1578
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1579
+ hidden_states = hidden_states.to(query.dtype)
1580
+
1581
+ # Split the attention outputs.
1582
+ if encoder_hidden_states is not None:
1583
+ hidden_states, encoder_hidden_states = (
1584
+ hidden_states[:, encoder_hidden_states.shape[1] :],
1585
+ hidden_states[:, : encoder_hidden_states.shape[1]],
1586
+ )
1587
+
1588
+ # linear proj
1589
+ hidden_states = attn.to_out[0](hidden_states)
1590
+ # dropout
1591
+ hidden_states = attn.to_out[1](hidden_states)
1592
+ if encoder_hidden_states is not None:
1593
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
1594
+
1595
+ if encoder_hidden_states is not None:
1596
+ return hidden_states, encoder_hidden_states
1597
+ else:
1598
+ return hidden_states
1599
+
1600
+
1601
+ class FusedAuraFlowAttnProcessor2_0:
1602
+ """Attention processor used typically in processing Aura Flow with fused projections."""
1603
+
1604
+ def __init__(self):
1605
+ if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
1606
+ raise ImportError(
1607
+ "FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
1608
+ )
1609
+
1610
+ def __call__(
1611
+ self,
1612
+ attn: Attention,
1613
+ hidden_states: torch.FloatTensor,
1614
+ encoder_hidden_states: torch.FloatTensor = None,
1615
+ *args,
1616
+ **kwargs,
1617
+ ) -> torch.FloatTensor:
1618
+ batch_size = hidden_states.shape[0]
1619
+
1620
+ # `sample` projections.
1621
+ qkv = attn.to_qkv(hidden_states)
1622
+ split_size = qkv.shape[-1] // 3
1623
+ query, key, value = torch.split(qkv, split_size, dim=-1)
1624
+
1625
+ # `context` projections.
1626
+ if encoder_hidden_states is not None:
1627
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
1628
+ split_size = encoder_qkv.shape[-1] // 3
1629
+ (
1630
+ encoder_hidden_states_query_proj,
1631
+ encoder_hidden_states_key_proj,
1632
+ encoder_hidden_states_value_proj,
1633
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
1634
+
1635
+ # Reshape.
1636
+ inner_dim = key.shape[-1]
1637
+ head_dim = inner_dim // attn.heads
1638
+ query = query.view(batch_size, -1, attn.heads, head_dim)
1639
+ key = key.view(batch_size, -1, attn.heads, head_dim)
1640
+ value = value.view(batch_size, -1, attn.heads, head_dim)
1641
+
1642
+ # Apply QK norm.
1643
+ if attn.norm_q is not None:
1644
+ query = attn.norm_q(query)
1645
+ if attn.norm_k is not None:
1646
+ key = attn.norm_k(key)
1647
+
1648
+ # Concatenate the projections.
1649
+ if encoder_hidden_states is not None:
1650
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
1651
+ batch_size, -1, attn.heads, head_dim
1652
+ )
1653
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
1654
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
1655
+ batch_size, -1, attn.heads, head_dim
1656
+ )
1657
+
1658
+ if attn.norm_added_q is not None:
1659
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
1660
+ if attn.norm_added_k is not None:
1661
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
1662
+
1663
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
1664
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
1665
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
1666
+
1667
+ query = query.transpose(1, 2)
1668
+ key = key.transpose(1, 2)
1669
+ value = value.transpose(1, 2)
1670
+
1671
+ # Attention.
1672
+ hidden_states = F.scaled_dot_product_attention(
1673
+ query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
1674
+ )
1675
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1676
+ hidden_states = hidden_states.to(query.dtype)
1677
+
1678
+ # Split the attention outputs.
1679
+ if encoder_hidden_states is not None:
1680
+ hidden_states, encoder_hidden_states = (
1681
+ hidden_states[:, encoder_hidden_states.shape[1] :],
1682
+ hidden_states[:, : encoder_hidden_states.shape[1]],
1683
+ )
1684
+
1685
+ # linear proj
1686
+ hidden_states = attn.to_out[0](hidden_states)
1687
+ # dropout
1688
+ hidden_states = attn.to_out[1](hidden_states)
1689
+ if encoder_hidden_states is not None:
1690
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
1691
+
1692
+ if encoder_hidden_states is not None:
1693
+ return hidden_states, encoder_hidden_states
1694
+ else:
1695
+ return hidden_states
1696
+
1697
+
1698
+ # YiYi to-do: refactor rope related functions/classes
1699
+ def apply_rope(xq, xk, freqs_cis):
1700
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
1701
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
1702
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
1703
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
1704
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
1705
+
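`apply_rope` rotates consecutive channel pairs of the query and key by position-dependent angles. The sketch below re-states the function so it is self-contained and feeds it a hypothetical `freqs_cis` packed as per-position 2x2 rotation matrices; the exact frequency schedule and tensor layout produced by the Flux embedding layer may differ.

import torch

def apply_rope(xq, xk, freqs_cis):
    # Copied from the function above so the sketch is self-contained.
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

# Hypothetical shapes: (batch, heads, seq, head_dim).
B, H, S, D = 1, 2, 8, 16
xq, xk = torch.randn(B, H, S, D), torch.randn(B, H, S, D)

# Assumed freqs_cis layout: one 2x2 rotation matrix [[cos, -sin], [sin, cos]]
# per (position, channel pair), broadcastable over batch and heads.
freqs = 1.0 / 10000.0 ** (torch.arange(0, D, 2).float() / D)  # (D/2,)
angles = torch.outer(torch.arange(S).float(), freqs)          # (S, D/2)
cos, sin = angles.cos(), angles.sin()
freqs_cis = torch.stack(
    [torch.stack([cos, -sin], dim=-1), torch.stack([sin, cos], dim=-1)], dim=-2
)[None, None]                                                 # (1, 1, S, D/2, 2, 2)

xq_rot, xk_rot = apply_rope(xq, xk, freqs_cis)
assert xq_rot.shape == xq.shape and xk_rot.dtype == xk.dtype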
1706
+
1707
+ class FluxSingleAttnProcessor2_0:
1708
+ r"""
1709
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
1710
+ """
1711
+
1712
+ def __init__(self):
1713
+ if not hasattr(F, "scaled_dot_product_attention"):
1714
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1715
+
1716
+ def __call__(
1717
+ self,
1718
+ attn: Attention,
1719
+ hidden_states: torch.Tensor,
1720
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1721
+ attention_mask: Optional[torch.FloatTensor] = None,
1722
+ image_rotary_emb: Optional[torch.Tensor] = None,
1723
+ ) -> torch.Tensor:
1724
+ input_ndim = hidden_states.ndim
1725
+
1726
+ if input_ndim == 4:
1727
+ batch_size, channel, height, width = hidden_states.shape
1728
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1729
+
1730
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1731
+
1732
+ query = attn.to_q(hidden_states)
1733
+ if encoder_hidden_states is None:
1734
+ encoder_hidden_states = hidden_states
1735
+
1736
+ key = attn.to_k(encoder_hidden_states)
1737
+ value = attn.to_v(encoder_hidden_states)
1738
+
1739
+ inner_dim = key.shape[-1]
1740
+ head_dim = inner_dim // attn.heads
1741
+
1742
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1743
+
1744
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1745
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1746
+
1747
+ if attn.norm_q is not None:
1748
+ query = attn.norm_q(query)
1749
+ if attn.norm_k is not None:
1750
+ key = attn.norm_k(key)
1751
+
1752
+ # Apply RoPE if needed
1753
+ if image_rotary_emb is not None:
1754
+ # YiYi to-do: update using apply_rotary_emb
1755
+ # from ..embeddings import apply_rotary_emb
1756
+ # query = apply_rotary_emb(query, image_rotary_emb)
1757
+ # key = apply_rotary_emb(key, image_rotary_emb)
1758
+ query, key = apply_rope(query, key, image_rotary_emb)
1759
+
1760
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1761
+ # TODO: add support for attn.scale when we move to Torch 2.1
1762
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
1763
+
1764
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1765
+ hidden_states = hidden_states.to(query.dtype)
1766
+
1767
+ if input_ndim == 4:
1768
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1769
+
1770
+ return hidden_states
1771
+
1772
+
1773
+ class FluxAttnProcessor2_0:
1774
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
1775
+
1776
+ def __init__(self):
1777
+ if not hasattr(F, "scaled_dot_product_attention"):
1778
+ raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1779
+
1780
+ def __call__(
1781
+ self,
1782
+ attn: Attention,
1783
+ hidden_states: torch.FloatTensor,
1784
+ encoder_hidden_states: torch.FloatTensor = None,
1785
+ attention_mask: Optional[torch.FloatTensor] = None,
1786
+ image_rotary_emb: Optional[torch.Tensor] = None,
1787
+ ) -> torch.FloatTensor:
1788
+ input_ndim = hidden_states.ndim
1789
+ if input_ndim == 4:
1790
+ batch_size, channel, height, width = hidden_states.shape
1791
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1792
+ context_input_ndim = encoder_hidden_states.ndim
1793
+ if context_input_ndim == 4:
1794
+ batch_size, channel, height, width = encoder_hidden_states.shape
1795
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1796
+
1797
+ batch_size = encoder_hidden_states.shape[0]
1798
+
1799
+ # `sample` projections.
1800
+ query = attn.to_q(hidden_states)
1801
+ key = attn.to_k(hidden_states)
1802
+ value = attn.to_v(hidden_states)
1803
+
1804
+ inner_dim = key.shape[-1]
1805
+ head_dim = inner_dim // attn.heads
1806
+
1807
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1808
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1809
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1810
+
1811
+ if attn.norm_q is not None:
1812
+ query = attn.norm_q(query)
1813
+ if attn.norm_k is not None:
1814
+ key = attn.norm_k(key)
1815
+
1816
+ # `context` projections.
1817
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
1818
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1819
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1820
+
1821
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
1822
+ batch_size, -1, attn.heads, head_dim
1823
+ ).transpose(1, 2)
1824
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
1825
+ batch_size, -1, attn.heads, head_dim
1826
+ ).transpose(1, 2)
1827
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
1828
+ batch_size, -1, attn.heads, head_dim
1829
+ ).transpose(1, 2)
1830
+
1831
+ if attn.norm_added_q is not None:
1832
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
1833
+ if attn.norm_added_k is not None:
1834
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
1835
+
1836
+ # attention
1837
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
1838
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
1839
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
1840
+
1841
+ if image_rotary_emb is not None:
1842
+ # YiYi to-do: update using apply_rotary_emb
1843
+ # from ..embeddings import apply_rotary_emb
1844
+ # query = apply_rotary_emb(query, image_rotary_emb)
1845
+ # key = apply_rotary_emb(key, image_rotary_emb)
1846
+ query, key = apply_rope(query, key, image_rotary_emb)
1847
+
1848
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
1849
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1850
+ hidden_states = hidden_states.to(query.dtype)
1851
+
1852
+ encoder_hidden_states, hidden_states = (
1853
+ hidden_states[:, : encoder_hidden_states.shape[1]],
1854
+ hidden_states[:, encoder_hidden_states.shape[1] :],
1855
+ )
1856
+
1857
+ # linear proj
1858
+ hidden_states = attn.to_out[0](hidden_states)
1859
+ # dropout
1860
+ hidden_states = attn.to_out[1](hidden_states)
1861
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
1862
+
1863
+ if input_ndim == 4:
1864
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1865
+ if context_input_ndim == 4:
1866
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1867
+
1868
+ return hidden_states, encoder_hidden_states
1869
+
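Note the ordering in FluxAttnProcessor2_0 (and the AuraFlow processors above): the `encoder_hidden_states` projections are concatenated in front of the sample projections, and the joint attention output is split back in that same order, which is the reverse of the sample-first convention used by the joint SD3 processors earlier in this file. A shape-only sketch with hypothetical token counts:

import torch

batch, heads, head_dim = 2, 8, 32
text_len, image_len = 5, 12                    # hypothetical token counts
text_q = torch.randn(batch, heads, text_len, head_dim)
image_q = torch.randn(batch, heads, image_len, head_dim)

# Text tokens go first, matching the torch.cat(..., dim=2) calls above.
query = torch.cat([text_q, image_q], dim=2)
assert query.shape[2] == text_len + image_len

# After attention, the combined sequence is split back in the same order.
joint_out = torch.randn(batch, text_len + image_len, heads * head_dim)
encoder_out, sample_out = joint_out[:, :text_len], joint_out[:, text_len:]
assert encoder_out.shape[1] == text_len and sample_out.shape[1] == image_len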
1870
+
1871
+ class XFormersAttnAddedKVProcessor:
1872
+ r"""
1873
+ Processor for implementing memory efficient attention using xFormers.
1874
+
1875
+ Args:
1876
+ attention_op (`Callable`, *optional*, defaults to `None`):
1877
+ The base
1878
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1879
+ use as the attention operator. It is recommended to set this to `None` and allow xFormers to choose the best
1880
+ operator.
1881
+ """
1882
+
1883
+ def __init__(self, attention_op: Optional[Callable] = None):
1884
+ self.attention_op = attention_op
1885
+
1886
+ def __call__(
1887
+ self,
1888
+ attn: Attention,
1889
+ hidden_states: torch.Tensor,
1890
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1891
+ attention_mask: Optional[torch.Tensor] = None,
1892
+ ) -> torch.Tensor:
1893
+ residual = hidden_states
1894
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
1895
+ batch_size, sequence_length, _ = hidden_states.shape
1896
+
1897
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1898
+
1899
+ if encoder_hidden_states is None:
1900
+ encoder_hidden_states = hidden_states
1901
+ elif attn.norm_cross:
1902
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1903
+
1904
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1905
+
1906
+ query = attn.to_q(hidden_states)
1907
+ query = attn.head_to_batch_dim(query)
1908
+
1909
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1910
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1911
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
1912
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
1913
+
1914
+ if not attn.only_cross_attention:
1915
+ key = attn.to_k(hidden_states)
1916
+ value = attn.to_v(hidden_states)
1917
+ key = attn.head_to_batch_dim(key)
1918
+ value = attn.head_to_batch_dim(value)
1919
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
1920
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
1921
+ else:
1922
+ key = encoder_hidden_states_key_proj
1923
+ value = encoder_hidden_states_value_proj
1924
+
1925
+ hidden_states = xformers.ops.memory_efficient_attention(
1926
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1927
+ )
1928
+ hidden_states = hidden_states.to(query.dtype)
1929
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1930
+
1931
+ # linear proj
1932
+ hidden_states = attn.to_out[0](hidden_states)
1933
+ # dropout
1934
+ hidden_states = attn.to_out[1](hidden_states)
1935
+
1936
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
1937
+ hidden_states = hidden_states + residual
1938
+
1939
+ return hidden_states
1940
+
1941
+
1942
+ class XFormersAttnProcessor:
1943
+ r"""
1944
+ Processor for implementing memory efficient attention using xFormers.
1945
+
1946
+ Args:
1947
+ attention_op (`Callable`, *optional*, defaults to `None`):
1948
+ The base
1949
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1950
+ use as the attention operator. It is recommended to set this to `None` and allow xFormers to choose the best
1951
+ operator.
1952
+ """
1953
+
1954
+ def __init__(self, attention_op: Optional[Callable] = None):
1955
+ self.attention_op = attention_op
1956
+
1957
+ def __call__(
1958
+ self,
1959
+ attn: Attention,
1960
+ hidden_states: torch.Tensor,
1961
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1962
+ attention_mask: Optional[torch.Tensor] = None,
1963
+ temb: Optional[torch.Tensor] = None,
1964
+ *args,
1965
+ **kwargs,
1966
+ ) -> torch.Tensor:
1967
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
1968
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1969
+ deprecate("scale", "1.0.0", deprecation_message)
1970
+
1971
+ residual = hidden_states
1972
+
1973
+ if attn.spatial_norm is not None:
1974
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1975
+
1976
+ input_ndim = hidden_states.ndim
1977
+
1978
+ if input_ndim == 4:
1979
+ batch_size, channel, height, width = hidden_states.shape
1980
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1981
+
1982
+ batch_size, key_tokens, _ = (
1983
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1984
+ )
1985
+
1986
+ attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
1987
+ if attention_mask is not None:
1988
+ # expand our mask's singleton query_tokens dimension:
1989
+ # [batch*heads, 1, key_tokens] ->
1990
+ # [batch*heads, query_tokens, key_tokens]
1991
+ # so that it can be added as a bias onto the attention scores that xformers computes:
1992
+ # [batch*heads, query_tokens, key_tokens]
1993
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
1994
+ _, query_tokens, _ = hidden_states.shape
1995
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
1996
+
1997
+ if attn.group_norm is not None:
1998
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1999
+
2000
+ query = attn.to_q(hidden_states)
2001
+
2002
+ if encoder_hidden_states is None:
2003
+ encoder_hidden_states = hidden_states
2004
+ elif attn.norm_cross:
2005
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
2006
+
2007
+ key = attn.to_k(encoder_hidden_states)
2008
+ value = attn.to_v(encoder_hidden_states)
2009
+
2010
+ query = attn.head_to_batch_dim(query).contiguous()
2011
+ key = attn.head_to_batch_dim(key).contiguous()
2012
+ value = attn.head_to_batch_dim(value).contiguous()
2013
+
2014
+ hidden_states = xformers.ops.memory_efficient_attention(
2015
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
2016
+ )
2017
+ hidden_states = hidden_states.to(query.dtype)
2018
+ hidden_states = attn.batch_to_head_dim(hidden_states)
2019
+
2020
+ # linear proj
2021
+ hidden_states = attn.to_out[0](hidden_states)
2022
+ # dropout
2023
+ hidden_states = attn.to_out[1](hidden_states)
2024
+
2025
+ if input_ndim == 4:
2026
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
2027
+
2028
+ if attn.residual_connection:
2029
+ hidden_states = hidden_states + residual
2030
+
2031
+ hidden_states = hidden_states / attn.rescale_output_factor
2032
+
2033
+ return hidden_states
2034
+
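As the comment in XFormersAttnProcessor notes, xFormers does not broadcast the singleton query dimension of the bias, so the mask is expanded explicitly. A shape-only sketch in plain PyTorch (no xformers import), with hypothetical sizes:

import torch

batch_heads, query_tokens, key_tokens = 16, 96, 77   # hypothetical sizes
attention_mask = torch.zeros(batch_heads, 1, key_tokens)

# xFormers expects the bias at full (batch*heads, query_tokens, key_tokens);
# expand materializes that shape as a stride-0 view, without copying memory.
attention_mask = attention_mask.expand(-1, query_tokens, -1)
assert attention_mask.shape == (batch_heads, query_tokens, key_tokens)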
2035
+
2036
+ class AttnProcessorNPU:
2037
+ r"""
2038
+ Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
2039
+ fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
2040
+ not significant.
2041
+
2042
+ """
2043
+
2044
+ def __init__(self):
2045
+ if not is_torch_npu_available():
2046
+ raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")
2047
+
2048
+ def __call__(
2049
+ self,
2050
+ attn: Attention,
2051
+ hidden_states: torch.Tensor,
2052
+ encoder_hidden_states: Optional[torch.Tensor] = None,
2053
+ attention_mask: Optional[torch.Tensor] = None,
2054
+ temb: Optional[torch.Tensor] = None,
2055
+ *args,
2056
+ **kwargs,
2057
+ ) -> torch.Tensor:
2058
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
2059
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
2060
+ deprecate("scale", "1.0.0", deprecation_message)
2061
+
2062
+ residual = hidden_states
2063
+ if attn.spatial_norm is not None:
2064
+ hidden_states = attn.spatial_norm(hidden_states, temb)
2065
+
2066
+ input_ndim = hidden_states.ndim
2067
+
2068
+ if input_ndim == 4:
2069
+ batch_size, channel, height, width = hidden_states.shape
2070
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
2071
+
2072
+ batch_size, sequence_length, _ = (
2073
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
2074
+ )
2075
+
2076
+ if attention_mask is not None:
2077
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
2078
+ # scaled_dot_product_attention expects attention_mask shape to be
2079
+ # (batch, heads, source_length, target_length)
2080
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
2081
+
2082
+ if attn.group_norm is not None:
2083
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
2084
+
2085
+ query = attn.to_q(hidden_states)
2086
+
2087
+ if encoder_hidden_states is None:
2088
+ encoder_hidden_states = hidden_states
2089
+ elif attn.norm_cross:
2090
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
2091
+
2092
+ key = attn.to_k(encoder_hidden_states)
2093
+ value = attn.to_v(encoder_hidden_states)
2094
+
2095
+ inner_dim = key.shape[-1]
2096
+ head_dim = inner_dim // attn.heads
2097
+
2098
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2099
+
2100
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2101
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2102
+
2103
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
2104
+ if query.dtype in (torch.float16, torch.bfloat16):
2105
+ hidden_states = torch_npu.npu_fusion_attention(
2106
+ query,
2107
+ key,
2108
+ value,
2109
+ attn.heads,
2110
+ input_layout="BNSD",
2111
+ pse=None,
2112
+ atten_mask=attention_mask,
2113
+ scale=1.0 / math.sqrt(query.shape[-1]),
2114
+ pre_tockens=65536,
2115
+ next_tockens=65536,
2116
+ keep_prob=1.0,
2117
+ sync=False,
2118
+ inner_precise=0,
2119
+ )[0]
2120
+ else:
2121
+ # TODO: add support for attn.scale when we move to Torch 2.1
2122
+ hidden_states = F.scaled_dot_product_attention(
2123
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
2124
+ )
2125
+
2126
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2127
+ hidden_states = hidden_states.to(query.dtype)
2128
+
2129
+ # linear proj
2130
+ hidden_states = attn.to_out[0](hidden_states)
2131
+ # dropout
2132
+ hidden_states = attn.to_out[1](hidden_states)
2133
+
2134
+ if input_ndim == 4:
2135
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
2136
+
2137
+ if attn.residual_connection:
2138
+ hidden_states = hidden_states + residual
2139
+
2140
+ hidden_states = hidden_states / attn.rescale_output_factor
2141
+
2142
+ return hidden_states
2143
+
2144
+
2145
+ class AttnProcessor2_0:
2146
+ r"""
2147
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
2148
+ """
2149
+
2150
+ def __init__(self):
2151
+ if not hasattr(F, "scaled_dot_product_attention"):
2152
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
2153
+
2154
+ def __call__(
2155
+ self,
2156
+ attn: Attention,
2157
+ hidden_states: torch.Tensor,
2158
+ encoder_hidden_states: Optional[torch.Tensor] = None,
2159
+ attention_mask: Optional[torch.Tensor] = None,
2160
+ temb: Optional[torch.Tensor] = None,
2161
+ *args,
2162
+ **kwargs,
2163
+ ) -> torch.Tensor:
2164
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
2165
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
2166
+ deprecate("scale", "1.0.0", deprecation_message)
2167
+
2168
+ residual = hidden_states
2169
+ if attn.spatial_norm is not None:
2170
+ hidden_states = attn.spatial_norm(hidden_states, temb)
2171
+
2172
+ input_ndim = hidden_states.ndim
2173
+
2174
+ if input_ndim == 4:
2175
+ batch_size, channel, height, width = hidden_states.shape
2176
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
2177
+
2178
+ batch_size, sequence_length, _ = (
2179
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
2180
+ )
2181
+
2182
+ if attention_mask is not None:
2183
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
2184
+ # scaled_dot_product_attention expects attention_mask shape to be
2185
+ # (batch, heads, source_length, target_length)
2186
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
2187
+
2188
+ if attn.group_norm is not None:
2189
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
2190
+
2191
+ query = attn.to_q(hidden_states)
2192
+
2193
+ if encoder_hidden_states is None:
2194
+ encoder_hidden_states = hidden_states
2195
+ elif attn.norm_cross:
2196
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
2197
+
2198
+ key = attn.to_k(encoder_hidden_states)
2199
+ value = attn.to_v(encoder_hidden_states)
2200
+
2201
+ inner_dim = key.shape[-1]
2202
+ head_dim = inner_dim // attn.heads
2203
+
2204
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2205
+
2206
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2207
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2208
+
2209
+ if attn.norm_q is not None:
2210
+ query = attn.norm_q(query)
2211
+ if attn.norm_k is not None:
2212
+ key = attn.norm_k(key)
2213
+
2214
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
2215
+ # TODO: add support for attn.scale when we move to Torch 2.1
2216
+ hidden_states = F.scaled_dot_product_attention(
2217
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
2218
+ )
2219
+
2220
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2221
+ hidden_states = hidden_states.to(query.dtype)
2222
+
2223
+ # linear proj
2224
+ hidden_states = attn.to_out[0](hidden_states)
2225
+ # dropout
2226
+ hidden_states = attn.to_out[1](hidden_states)
2227
+
2228
+ if input_ndim == 4:
2229
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
2230
+
2231
+ if attn.residual_connection:
2232
+ hidden_states = hidden_states + residual
2233
+
2234
+ hidden_states = hidden_states / attn.rescale_output_factor
2235
+
2236
+ return hidden_states
2237
+
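Most processors above follow the same reshape-to-heads / `scaled_dot_product_attention` / reshape-back pattern as AttnProcessor2_0. A standalone sketch of just that skeleton, with hypothetical dimensions (requires PyTorch 2.0 for `F.scaled_dot_product_attention`):

import torch
import torch.nn.functional as F

batch_size, seq_len, heads, head_dim = 2, 64, 8, 40    # hypothetical sizes
inner_dim = heads * head_dim
query = torch.randn(batch_size, seq_len, inner_dim)
key = torch.randn(batch_size, seq_len, inner_dim)
value = torch.randn(batch_size, seq_len, inner_dim)

# (batch, tokens, inner_dim) -> (batch, heads, tokens, head_dim)
query = query.view(batch_size, -1, heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, heads, head_dim).transpose(1, 2)

# Fused scaled dot-product attention; a mask, when present, must be shaped
# (batch, heads, query_tokens, key_tokens), as noted in the processors above.
out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

# (batch, heads, tokens, head_dim) -> (batch, tokens, inner_dim)
out = out.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
assert out.shape == (batch_size, seq_len, inner_dim)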
2238
+
2239
+ class StableAudioAttnProcessor2_0:
2240
+ r"""
2241
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
2242
+ used in the Stable Audio model. It applies rotary embedding on the query and key vectors, and allows MHA, GQA, or MQA.
2243
+ """
2244
+
2245
+ def __init__(self):
2246
+ if not hasattr(F, "scaled_dot_product_attention"):
2247
+ raise ImportError(
2248
+ "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
2249
+ )
2250
+
2251
+ def apply_partial_rotary_emb(
2252
+ self,
2253
+ x: torch.Tensor,
2254
+ freqs_cis: Tuple[torch.Tensor],
2255
+ ) -> torch.Tensor:
2256
+ from .embeddings import apply_rotary_emb
2257
+
2258
+ rot_dim = freqs_cis[0].shape[-1]
2259
+ x_to_rotate, x_unrotated = x[..., :rot_dim], x[..., rot_dim:]
2260
+
2261
+ x_rotated = apply_rotary_emb(x_to_rotate, freqs_cis, use_real=True, use_real_unbind_dim=-2)
2262
+
2263
+ out = torch.cat((x_rotated, x_unrotated), dim=-1)
2264
+ return out
2265
+
2266
+ def __call__(
2267
+ self,
2268
+ attn: Attention,
2269
+ hidden_states: torch.Tensor,
2270
+ encoder_hidden_states: Optional[torch.Tensor] = None,
2271
+ attention_mask: Optional[torch.Tensor] = None,
2272
+ rotary_emb: Optional[torch.Tensor] = None,
2273
+ ) -> torch.Tensor:
2274
+ from .embeddings import apply_rotary_emb
2275
+
2276
+ residual = hidden_states
2277
+
2278
+ input_ndim = hidden_states.ndim
2279
+
2280
+ if input_ndim == 4:
2281
+ batch_size, channel, height, width = hidden_states.shape
2282
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
2283
+
2284
+ batch_size, sequence_length, _ = (
2285
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
2286
+ )
2287
+
2288
+ if attention_mask is not None:
2289
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
2290
+ # scaled_dot_product_attention expects attention_mask shape to be
2291
+ # (batch, heads, source_length, target_length)
2292
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
2293
+
2294
+ query = attn.to_q(hidden_states)
2295
+
2296
+ if encoder_hidden_states is None:
2297
+ encoder_hidden_states = hidden_states
2298
+ elif attn.norm_cross:
2299
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
2300
+
2301
+ key = attn.to_k(encoder_hidden_states)
2302
+ value = attn.to_v(encoder_hidden_states)
2303
+
2304
+ head_dim = query.shape[-1] // attn.heads
2305
+ kv_heads = key.shape[-1] // head_dim
2306
+
2307
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2308
+
2309
+ key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
2310
+ value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
2311
+
2312
+ if kv_heads != attn.heads:
2313
+ # if GQA or MQA, repeat the key/value heads to reach the number of query heads.
2314
+ heads_per_kv_head = attn.heads // kv_heads
2315
+ key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
2316
+ value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
2317
+
2318
+ if attn.norm_q is not None:
2319
+ query = attn.norm_q(query)
2320
+ if attn.norm_k is not None:
2321
+ key = attn.norm_k(key)
2322
+
2323
+ # Apply RoPE if needed
2324
+ if rotary_emb is not None:
2325
+ query_dtype = query.dtype
2326
+ key_dtype = key.dtype
2327
+ query = query.to(torch.float32)
2328
+ key = key.to(torch.float32)
2329
+
2330
+ rot_dim = rotary_emb[0].shape[-1]
2331
+ query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:]
2332
+ query_rotated = apply_rotary_emb(query_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2)
2333
+
2334
+ query = torch.cat((query_rotated, query_unrotated), dim=-1)
2335
+
2336
+ if not attn.is_cross_attention:
2337
+ key_to_rotate, key_unrotated = key[..., :rot_dim], key[..., rot_dim:]
2338
+ key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2)
2339
+
2340
+ key = torch.cat((key_rotated, key_unrotated), dim=-1)
2341
+
2342
+ query = query.to(query_dtype)
2343
+ key = key.to(key_dtype)
2344
+
2345
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
2346
+ # TODO: add support for attn.scale when we move to Torch 2.1
2347
+ hidden_states = F.scaled_dot_product_attention(
2348
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
2349
+ )
2350
+
2351
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2352
+ hidden_states = hidden_states.to(query.dtype)
2353
+
2354
+ # linear proj
2355
+ hidden_states = attn.to_out[0](hidden_states)
2356
+ # dropout
2357
+ hidden_states = attn.to_out[1](hidden_states)
2358
+
2359
+ if input_ndim == 4:
2360
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
2361
+
2362
+ if attn.residual_connection:
2363
+ hidden_states = hidden_states + residual
2364
+
2365
+ hidden_states = hidden_states / attn.rescale_output_factor
2366
+
2367
+ return hidden_states
2368
+
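StableAudioAttnProcessor2_0 allows fewer key/value heads than query heads (GQA/MQA) and repeats each key/value head to match the query heads. A minimal sketch with hypothetical head counts:

import torch

batch, seq, head_dim = 2, 32, 64
heads, kv_heads = 8, 2                      # hypothetical GQA configuration

key = torch.randn(batch, kv_heads, seq, head_dim)
value = torch.randn(batch, kv_heads, seq, head_dim)

# Each key/value head serves heads // kv_heads query heads.
heads_per_kv_head = heads // kv_heads
key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
assert key.shape[1] == heads and value.shape[1] == heads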
2369
+
2370
+ class HunyuanAttnProcessor2_0:
2371
+ r"""
2372
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
2373
+ used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on the query and key vectors.
2374
+ """
2375
+
2376
+ def __init__(self):
2377
+ if not hasattr(F, "scaled_dot_product_attention"):
2378
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
2379
+
2380
+ def __call__(
2381
+ self,
2382
+ attn: Attention,
2383
+ hidden_states: torch.Tensor,
2384
+ encoder_hidden_states: Optional[torch.Tensor] = None,
2385
+ attention_mask: Optional[torch.Tensor] = None,
2386
+ temb: Optional[torch.Tensor] = None,
2387
+ image_rotary_emb: Optional[torch.Tensor] = None,
2388
+ ) -> torch.Tensor:
2389
+ from .embeddings import apply_rotary_emb
2390
+
2391
+ residual = hidden_states
2392
+ if attn.spatial_norm is not None:
2393
+ hidden_states = attn.spatial_norm(hidden_states, temb)
2394
+
2395
+ input_ndim = hidden_states.ndim
1440
2396
 
1441
2397
  if input_ndim == 4:
1442
2398
  batch_size, channel, height, width = hidden_states.shape
@@ -1473,28 +2429,22 @@ class AttnProcessorNPU:
1473
2429
  key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1474
2430
  value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1475
2431
 
2432
+ if attn.norm_q is not None:
2433
+ query = attn.norm_q(query)
2434
+ if attn.norm_k is not None:
2435
+ key = attn.norm_k(key)
2436
+
2437
+ # Apply RoPE if needed
2438
+ if image_rotary_emb is not None:
2439
+ query = apply_rotary_emb(query, image_rotary_emb)
2440
+ if not attn.is_cross_attention:
2441
+ key = apply_rotary_emb(key, image_rotary_emb)
2442
+
1476
2443
  # the output of sdp = (batch, num_heads, seq_len, head_dim)
1477
- if query.dtype in (torch.float16, torch.bfloat16):
1478
- hidden_states = torch_npu.npu_fusion_attention(
1479
- query,
1480
- key,
1481
- value,
1482
- attn.heads,
1483
- input_layout="BNSD",
1484
- pse=None,
1485
- atten_mask=attention_mask,
1486
- scale=1.0 / math.sqrt(query.shape[-1]),
1487
- pre_tockens=65536,
1488
- next_tockens=65536,
1489
- keep_prob=1.0,
1490
- sync=False,
1491
- inner_precise=0,
1492
- )[0]
1493
- else:
1494
- # TODO: add support for attn.scale when we move to Torch 2.1
1495
- hidden_states = F.scaled_dot_product_attention(
1496
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1497
- )
2444
+ # TODO: add support for attn.scale when we move to Torch 2.1
2445
+ hidden_states = F.scaled_dot_product_attention(
2446
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
2447
+ )
1498
2448
 
1499
2449
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1500
2450
  hidden_states = hidden_states.to(query.dtype)
@@ -1515,14 +2465,18 @@ class AttnProcessorNPU:
1515
2465
  return hidden_states
1516
2466
 
1517
2467
 
1518
- class AttnProcessor2_0:
2468
+ class FusedHunyuanAttnProcessor2_0:
1519
2469
  r"""
1520
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
2470
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused
2471
+ projection layers. This is used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on
2472
+ the query and key vectors.
1521
2473
  """
1522
2474
 
1523
2475
  def __init__(self):
1524
2476
  if not hasattr(F, "scaled_dot_product_attention"):
1525
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
2477
+ raise ImportError(
2478
+ "FusedHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
2479
+ )
1526
2480
 
1527
2481
  def __call__(
1528
2482
  self,
@@ -1531,12 +2485,9 @@ class AttnProcessor2_0:
1531
2485
  encoder_hidden_states: Optional[torch.Tensor] = None,
1532
2486
  attention_mask: Optional[torch.Tensor] = None,
1533
2487
  temb: Optional[torch.Tensor] = None,
1534
- *args,
1535
- **kwargs,
2488
+ image_rotary_emb: Optional[torch.Tensor] = None,
1536
2489
  ) -> torch.Tensor:
1537
- if len(args) > 0 or kwargs.get("scale", None) is not None:
1538
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1539
- deprecate("scale", "1.0.0", deprecation_message)
2490
+ from .embeddings import apply_rotary_emb
1540
2491
 
1541
2492
  residual = hidden_states
1542
2493
  if attn.spatial_norm is not None:
@@ -1561,24 +2512,37 @@ class AttnProcessor2_0:
1561
2512
  if attn.group_norm is not None:
1562
2513
  hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1563
2514
 
1564
- query = attn.to_q(hidden_states)
1565
-
1566
2515
  if encoder_hidden_states is None:
1567
- encoder_hidden_states = hidden_states
1568
- elif attn.norm_cross:
1569
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
2516
+ qkv = attn.to_qkv(hidden_states)
2517
+ split_size = qkv.shape[-1] // 3
2518
+ query, key, value = torch.split(qkv, split_size, dim=-1)
2519
+ else:
2520
+ if attn.norm_cross:
2521
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
2522
+ query = attn.to_q(hidden_states)
1570
2523
 
1571
- key = attn.to_k(encoder_hidden_states)
1572
- value = attn.to_v(encoder_hidden_states)
2524
+ kv = attn.to_kv(encoder_hidden_states)
2525
+ split_size = kv.shape[-1] // 2
2526
+ key, value = torch.split(kv, split_size, dim=-1)
1573
2527
 
1574
2528
  inner_dim = key.shape[-1]
1575
2529
  head_dim = inner_dim // attn.heads
1576
2530
 
1577
2531
  query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1578
-
1579
2532
  key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1580
2533
  value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1581
2534
 
2535
+ if attn.norm_q is not None:
2536
+ query = attn.norm_q(query)
2537
+ if attn.norm_k is not None:
2538
+ key = attn.norm_k(key)
2539
+
2540
+ # Apply RoPE if needed
2541
+ if image_rotary_emb is not None:
2542
+ query = apply_rotary_emb(query, image_rotary_emb)
2543
+ if not attn.is_cross_attention:
2544
+ key = apply_rotary_emb(key, image_rotary_emb)
2545
+
1582
2546
  # the output of sdp = (batch, num_heads, seq_len, head_dim)
1583
2547
  # TODO: add support for attn.scale when we move to Torch 2.1
1584
2548
  hidden_states = F.scaled_dot_product_attention(
@@ -1604,15 +2568,18 @@ class AttnProcessor2_0:
1604
2568
  return hidden_states
1605
2569
 
1606
2570
 
1607
- class HunyuanAttnProcessor2_0:
2571
+ class PAGHunyuanAttnProcessor2_0:
1608
2572
  r"""
1609
2573
  Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
1610
- used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on query and key vector.
2574
+ used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on the query and key vectors. This
2575
+ variant of the processor employs [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377).
1611
2576
  """
1612
2577
 
1613
2578
  def __init__(self):
1614
2579
  if not hasattr(F, "scaled_dot_product_attention"):
1615
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
2580
+ raise ImportError(
2581
+ "PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
2582
+ )
1616
2583
 
1617
2584
  def __call__(
1618
2585
  self,
@@ -1635,8 +2602,12 @@ class HunyuanAttnProcessor2_0:
1635
2602
  batch_size, channel, height, width = hidden_states.shape
1636
2603
  hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1637
2604
 
2605
+ # chunk
2606
+ hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
2607
+
2608
+ # 1. Original Path
1638
2609
  batch_size, sequence_length, _ = (
1639
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
2610
+ hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1640
2611
  )
1641
2612
 
1642
2613
  if attention_mask is not None:
@@ -1646,12 +2617,12 @@ class HunyuanAttnProcessor2_0:
1646
2617
  attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1647
2618
 
1648
2619
  if attn.group_norm is not None:
1649
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
2620
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
1650
2621
 
1651
- query = attn.to_q(hidden_states)
2622
+ query = attn.to_q(hidden_states_org)
1652
2623
 
1653
2624
  if encoder_hidden_states is None:
1654
- encoder_hidden_states = hidden_states
2625
+ encoder_hidden_states = hidden_states_org
1655
2626
  elif attn.norm_cross:
1656
2627
  encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1657
2628
 
@@ -1679,25 +2650,263 @@ class HunyuanAttnProcessor2_0:
1679
2650
 
1680
2651
  # the output of sdp = (batch, num_heads, seq_len, head_dim)
1681
2652
  # TODO: add support for attn.scale when we move to Torch 2.1
1682
- hidden_states = F.scaled_dot_product_attention(
2653
+ hidden_states_org = F.scaled_dot_product_attention(
1683
2654
  query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1684
2655
  )
1685
2656
 
1686
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1687
- hidden_states = hidden_states.to(query.dtype)
2657
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2658
+ hidden_states_org = hidden_states_org.to(query.dtype)
1688
2659
 
1689
2660
  # linear proj
1690
- hidden_states = attn.to_out[0](hidden_states)
2661
+ hidden_states_org = attn.to_out[0](hidden_states_org)
1691
2662
  # dropout
1692
- hidden_states = attn.to_out[1](hidden_states)
2663
+ hidden_states_org = attn.to_out[1](hidden_states_org)
2664
+
2665
+ if input_ndim == 4:
2666
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
2667
+
2668
+ # 2. Perturbed Path
2669
+ if attn.group_norm is not None:
2670
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
2671
+
2672
+ hidden_states_ptb = attn.to_v(hidden_states_ptb)
2673
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
2674
+
2675
+ # linear proj
2676
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
2677
+ # dropout
2678
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
2679
+
2680
+ if input_ndim == 4:
2681
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
2682
+
2683
+ # cat
2684
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
2685
+
2686
+ if attn.residual_connection:
2687
+ hidden_states = hidden_states + residual
2688
+
2689
+ hidden_states = hidden_states / attn.rescale_output_factor
2690
+
2691
+ return hidden_states
2692
+
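In the perturbed half of the PAG batch, the attention map is effectively replaced by the identity, so that branch reduces to the value projection followed by the output projection, while the original half runs full attention. A sketch of that batching pattern, with hypothetical linear layers standing in for `attn.to_v` and `attn.to_out`:

import torch
import torch.nn as nn

dim, seq = 64, 16
to_v = nn.Linear(dim, dim)                                    # stand-in for attn.to_v
to_out = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(0.0))  # stand-in for attn.to_out

# The incoming batch stacks the original and the perturbed samples.
hidden_states = torch.randn(4, seq, dim)
hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)

# 1. Original path: would run the full attention shown in the processor above
#    (elided in this sketch; the tensor is passed through unchanged).
# 2. Perturbed path: the attention map is replaced by the identity, so each
#    token only "sees" itself and the branch reduces to to_v followed by to_out.
hidden_states_ptb = to_out(to_v(hidden_states_ptb))

# Re-stack so the guidance step can later compare the two halves.
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
assert hidden_states.shape == (4, seq, dim)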
2693
+
2694
+ class PAGCFGHunyuanAttnProcessor2_0:
2695
+ r"""
2696
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
2697
+ used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This
2698
+ variant of the processor employs [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377).
2699
+ """
2700
+
2701
+ def __init__(self):
2702
+ if not hasattr(F, "scaled_dot_product_attention"):
2703
+ raise ImportError(
2704
+ "PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
2705
+ )
2706
+
2707
+ def __call__(
2708
+ self,
2709
+ attn: Attention,
2710
+ hidden_states: torch.Tensor,
2711
+ encoder_hidden_states: Optional[torch.Tensor] = None,
2712
+ attention_mask: Optional[torch.Tensor] = None,
2713
+ temb: Optional[torch.Tensor] = None,
2714
+ image_rotary_emb: Optional[torch.Tensor] = None,
2715
+ ) -> torch.Tensor:
2716
+ from .embeddings import apply_rotary_emb
2717
+
2718
+ residual = hidden_states
2719
+ if attn.spatial_norm is not None:
2720
+ hidden_states = attn.spatial_norm(hidden_states, temb)
2721
+
2722
+ input_ndim = hidden_states.ndim
2723
+
2724
+ if input_ndim == 4:
2725
+ batch_size, channel, height, width = hidden_states.shape
2726
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
2727
+
2728
+ # chunk
2729
+ hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
2730
+ hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
2731
+
2732
+ # 1. Original Path
2733
+ batch_size, sequence_length, _ = (
2734
+ hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
2735
+ )
2736
+
2737
+ if attention_mask is not None:
2738
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
2739
+ # scaled_dot_product_attention expects attention_mask shape to be
2740
+ # (batch, heads, source_length, target_length)
2741
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
2742
+
2743
+ if attn.group_norm is not None:
2744
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
2745
+
2746
+ query = attn.to_q(hidden_states_org)
2747
+
2748
+ if encoder_hidden_states is None:
2749
+ encoder_hidden_states = hidden_states_org
2750
+ elif attn.norm_cross:
2751
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
2752
+
2753
+ key = attn.to_k(encoder_hidden_states)
2754
+ value = attn.to_v(encoder_hidden_states)
2755
+
2756
+ inner_dim = key.shape[-1]
2757
+ head_dim = inner_dim // attn.heads
2758
+
2759
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2760
+
2761
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2762
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2763
+
2764
+ if attn.norm_q is not None:
2765
+ query = attn.norm_q(query)
2766
+ if attn.norm_k is not None:
2767
+ key = attn.norm_k(key)
2768
+
2769
+ # Apply RoPE if needed
2770
+ if image_rotary_emb is not None:
2771
+ query = apply_rotary_emb(query, image_rotary_emb)
2772
+ if not attn.is_cross_attention:
2773
+ key = apply_rotary_emb(key, image_rotary_emb)
2774
+
2775
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
2776
+ # TODO: add support for attn.scale when we move to Torch 2.1
2777
+ hidden_states_org = F.scaled_dot_product_attention(
2778
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
2779
+ )
2780
+
2781
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2782
+ hidden_states_org = hidden_states_org.to(query.dtype)
2783
+
2784
+ # linear proj
2785
+ hidden_states_org = attn.to_out[0](hidden_states_org)
2786
+ # dropout
2787
+ hidden_states_org = attn.to_out[1](hidden_states_org)
2788
+
2789
+ if input_ndim == 4:
2790
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
2791
+
2792
+ # 2. Perturbed Path
2793
+ if attn.group_norm is not None:
2794
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
2795
+
2796
+ hidden_states_ptb = attn.to_v(hidden_states_ptb)
2797
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
2798
+
2799
+ # linear proj
2800
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
2801
+ # dropout
2802
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
2803
+
2804
+ if input_ndim == 4:
2805
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
2806
+
2807
+ # cat
2808
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
2809
+
2810
+ if attn.residual_connection:
2811
+ hidden_states = hidden_states + residual
2812
+
2813
+ hidden_states = hidden_states / attn.rescale_output_factor
2814
+
2815
+ return hidden_states
2816
+
2817
+
2818
+ class LuminaAttnProcessor2_0:
2819
+ r"""
2820
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
2821
+ used in the LuminaNextDiT model. It applies a s normalization layer and rotary embedding on query and key vector.
2822
+ """
2823
+
2824
+ def __init__(self):
2825
+ if not hasattr(F, "scaled_dot_product_attention"):
2826
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
2827
+
2828
+ def __call__(
2829
+ self,
2830
+ attn: Attention,
2831
+ hidden_states: torch.Tensor,
2832
+ encoder_hidden_states: torch.Tensor,
2833
+ attention_mask: Optional[torch.Tensor] = None,
2834
+ query_rotary_emb: Optional[torch.Tensor] = None,
2835
+ key_rotary_emb: Optional[torch.Tensor] = None,
2836
+ base_sequence_length: Optional[int] = None,
2837
+ ) -> torch.Tensor:
2838
+ from .embeddings import apply_rotary_emb
2839
+
2840
+ input_ndim = hidden_states.ndim
2841
+
2842
+ if input_ndim == 4:
2843
+ batch_size, channel, height, width = hidden_states.shape
2844
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
2845
+
2846
+ batch_size, sequence_length, _ = hidden_states.shape
2847
+
2848
+ # Get Query-Key-Value Pair
2849
+ query = attn.to_q(hidden_states)
2850
+ key = attn.to_k(encoder_hidden_states)
2851
+ value = attn.to_v(encoder_hidden_states)
2852
+
2853
+ query_dim = query.shape[-1]
2854
+ inner_dim = key.shape[-1]
2855
+ head_dim = query_dim // attn.heads
2856
+ dtype = query.dtype
2857
+
2858
+ # Get key-value heads
2859
+ kv_heads = inner_dim // head_dim
2860
+
2861
+ # Apply Query-Key Norm if needed
2862
+ if attn.norm_q is not None:
2863
+ query = attn.norm_q(query)
2864
+ if attn.norm_k is not None:
2865
+ key = attn.norm_k(key)
2866
+
2867
+ query = query.view(batch_size, -1, attn.heads, head_dim)
2868
+
2869
+ key = key.view(batch_size, -1, kv_heads, head_dim)
2870
+ value = value.view(batch_size, -1, kv_heads, head_dim)
2871
+
2872
+ # Apply RoPE if needed
2873
+ if query_rotary_emb is not None:
2874
+ query = apply_rotary_emb(query, query_rotary_emb, use_real=False)
2875
+ if key_rotary_emb is not None:
2876
+ key = apply_rotary_emb(key, key_rotary_emb, use_real=False)
2877
+
2878
+ query, key = query.to(dtype), key.to(dtype)
2879
+
2880
+ # Apply proportional attention if true
2881
+ if key_rotary_emb is None:
2882
+ softmax_scale = None
2883
+ else:
2884
+ if base_sequence_length is not None:
2885
+ softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
2886
+ else:
2887
+ softmax_scale = attn.scale
1693
2888
 
1694
- if input_ndim == 4:
1695
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
2889
+ # perform Grouped-Query Attention (GQA)
2890
+ n_rep = attn.heads // kv_heads
2891
+ if n_rep >= 1:
2892
+ key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
2893
+ value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
1696
2894
 
1697
- if attn.residual_connection:
1698
- hidden_states = hidden_states + residual
2895
+ # scaled_dot_product_attention expects attention_mask shape to be
2896
+ # (batch, heads, source_length, target_length)
2897
+ attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
2898
+ attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)
1699
2899
 
1700
- hidden_states = hidden_states / attn.rescale_output_factor
2900
+ query = query.transpose(1, 2)
2901
+ key = key.transpose(1, 2)
2902
+ value = value.transpose(1, 2)
2903
+
2904
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
2905
+ # TODO: add support for attn.scale when we move to Torch 2.1
2906
+ hidden_states = F.scaled_dot_product_attention(
2907
+ query, key, value, attn_mask=attention_mask, scale=softmax_scale
2908
+ )
2909
+ hidden_states = hidden_states.transpose(1, 2).to(dtype)
1701
2910
 
1702
2911
  return hidden_states
1703
2912
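The Lumina processor above combines grouped-query attention (each key/value head is repeated so it lines up with the query heads) with a proportional softmax scale of sqrt(log_base(seq_len)) * attn.scale. A minimal sketch of those two steps in isolation; the tensor sizes below are illustrative and not taken from the diff:

```python
import math
import torch

# Illustrative sizes (not from the diff): 2 query heads per key/value head.
batch, seq_len, heads, kv_heads, head_dim = 1, 16, 8, 4, 32
base_sequence_length = 8
scale = head_dim**-0.5

key = torch.randn(batch, seq_len, kv_heads, head_dim)
value = torch.randn(batch, seq_len, kv_heads, head_dim)

# Grouped-query attention: repeat each KV head n_rep times along a new axis,
# then flatten so key/value have the same number of heads as the query.
n_rep = heads // kv_heads
key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)    # (batch, seq, heads, head_dim)
value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

# Proportional attention: grow the softmax scale with log_{base_len}(seq_len),
# mirroring softmax_scale = sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale.
softmax_scale = math.sqrt(math.log(seq_len, base_sequence_length)) * scale
print(key.shape, value.shape, softmax_scale)
```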
 
@@ -1778,6 +2987,11 @@ class FusedAttnProcessor2_0:
1778
2987
  key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1779
2988
  value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1780
2989
 
2990
+ if attn.norm_q is not None:
2991
+ query = attn.norm_q(query)
2992
+ if attn.norm_k is not None:
2993
+ key = attn.norm_k(key)
2994
+
1781
2995
  # the output of sdp = (batch, num_heads, seq_len, head_dim)
1782
2996
  # TODO: add support for attn.scale when we move to Torch 2.1
1783
2997
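The hunk above lets FusedAttnProcessor2_0 apply optional query/key normalization after the fused QKV projection is split into heads. The actual attn.norm_q / attn.norm_k modules depend on how the Attention block was configured (this sketch uses a LayerNorm over head_dim as a stand-in), so treat the following as illustrative only:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Stand-in QK-norm modules; the real attn.norm_q / attn.norm_k come from the
# Attention block's configuration, not necessarily LayerNorm.
batch, heads, seq, head_dim = 2, 8, 16, 64
norm_q = nn.LayerNorm(head_dim)
norm_k = nn.LayerNorm(head_dim)

query = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, heads, seq, head_dim)
value = torch.randn(batch, heads, seq, head_dim)

# Normalize query and key per head before the scaled-dot-product call,
# mirroring the `if attn.norm_q is not None: ...` branch added above.
query = norm_q(query)
key = norm_k(key)
out = F.scaled_dot_product_attention(query, key, value)
print(out.shape)  # (batch, heads, seq, head_dim)
```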
  hidden_states = F.scaled_dot_product_attention(
@@ -2088,7 +3302,7 @@ class SlicedAttnProcessor:
2088
3302
  (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
2089
3303
  )
2090
3304
 
2091
- for i in range(batch_size_attention // self.slice_size):
3305
+ for i in range((batch_size_attention - 1) // self.slice_size + 1):
2092
3306
  start_idx = i * self.slice_size
2093
3307
  end_idx = (i + 1) * self.slice_size
2094
3308
 
@@ -2185,7 +3399,7 @@ class SlicedAttnAddedKVProcessor:
2185
3399
  (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
2186
3400
  )
2187
3401
 
2188
- for i in range(batch_size_attention // self.slice_size):
3402
+ for i in range((batch_size_attention - 1) // self.slice_size + 1):
2189
3403
  start_idx = i * self.slice_size
2190
3404
  end_idx = (i + 1) * self.slice_size
2191
3405
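Both sliced-attention hunks replace a floor-division loop bound with a ceiling-style one, so the trailing partial slice is no longer skipped when batch_size_attention is not a multiple of slice_size. A quick check of the two formulas with illustrative values:

```python
# Illustrative values: 10 attention batches with a slice size of 4.
batch_size_attention, slice_size = 10, 4

old_num_slices = batch_size_attention // slice_size             # 2 -> rows 8..9 were never attended
new_num_slices = (batch_size_attention - 1) // slice_size + 1   # 3 -> covers every row

for i in range(new_num_slices):
    start_idx = i * slice_size
    end_idx = (i + 1) * slice_size  # slicing past the end is harmless for tensors
    print(start_idx, min(end_idx, batch_size_attention))  # (0, 4), (4, 8), (8, 10)
```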
 
@@ -2241,264 +3455,6 @@ class SpatialNorm(nn.Module):
2241
3455
  return new_f
2242
3456
 
2243
3457
 
2244
- class LoRAAttnProcessor(nn.Module):
2245
- def __init__(
2246
- self,
2247
- hidden_size: int,
2248
- cross_attention_dim: Optional[int] = None,
2249
- rank: int = 4,
2250
- network_alpha: Optional[int] = None,
2251
- **kwargs,
2252
- ):
2253
- deprecation_message = "Using LoRAAttnProcessor is deprecated. Please use the PEFT backend for all things LoRA. You can install PEFT by running `pip install peft`."
2254
- deprecate("LoRAAttnProcessor", "0.30.0", deprecation_message, standard_warn=False)
2255
-
2256
- super().__init__()
2257
-
2258
- self.hidden_size = hidden_size
2259
- self.cross_attention_dim = cross_attention_dim
2260
- self.rank = rank
2261
-
2262
- q_rank = kwargs.pop("q_rank", None)
2263
- q_hidden_size = kwargs.pop("q_hidden_size", None)
2264
- q_rank = q_rank if q_rank is not None else rank
2265
- q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
2266
-
2267
- v_rank = kwargs.pop("v_rank", None)
2268
- v_hidden_size = kwargs.pop("v_hidden_size", None)
2269
- v_rank = v_rank if v_rank is not None else rank
2270
- v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
2271
-
2272
- out_rank = kwargs.pop("out_rank", None)
2273
- out_hidden_size = kwargs.pop("out_hidden_size", None)
2274
- out_rank = out_rank if out_rank is not None else rank
2275
- out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
2276
-
2277
- self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
2278
- self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
2279
- self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
2280
- self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
2281
-
2282
- def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
2283
- self_cls_name = self.__class__.__name__
2284
- deprecate(
2285
- self_cls_name,
2286
- "0.26.0",
2287
- (
2288
- f"Make sure use {self_cls_name[4:]} instead by setting"
2289
- "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
2290
- " `LoraLoaderMixin.load_lora_weights`"
2291
- ),
2292
- )
2293
- attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
2294
- attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
2295
- attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
2296
- attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
2297
-
2298
- attn._modules.pop("processor")
2299
- attn.processor = AttnProcessor()
2300
- return attn.processor(attn, hidden_states, **kwargs)
2301
-
2302
-
2303
- class LoRAAttnProcessor2_0(nn.Module):
2304
- def __init__(
2305
- self,
2306
- hidden_size: int,
2307
- cross_attention_dim: Optional[int] = None,
2308
- rank: int = 4,
2309
- network_alpha: Optional[int] = None,
2310
- **kwargs,
2311
- ):
2312
- deprecation_message = "Using LoRAAttnProcessor is deprecated. Please use the PEFT backend for all things LoRA. You can install PEFT by running `pip install peft`."
2313
- deprecate("LoRAAttnProcessor2_0", "0.30.0", deprecation_message, standard_warn=False)
2314
-
2315
- super().__init__()
2316
- if not hasattr(F, "scaled_dot_product_attention"):
2317
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
2318
-
2319
- self.hidden_size = hidden_size
2320
- self.cross_attention_dim = cross_attention_dim
2321
- self.rank = rank
2322
-
2323
- q_rank = kwargs.pop("q_rank", None)
2324
- q_hidden_size = kwargs.pop("q_hidden_size", None)
2325
- q_rank = q_rank if q_rank is not None else rank
2326
- q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
2327
-
2328
- v_rank = kwargs.pop("v_rank", None)
2329
- v_hidden_size = kwargs.pop("v_hidden_size", None)
2330
- v_rank = v_rank if v_rank is not None else rank
2331
- v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
2332
-
2333
- out_rank = kwargs.pop("out_rank", None)
2334
- out_hidden_size = kwargs.pop("out_hidden_size", None)
2335
- out_rank = out_rank if out_rank is not None else rank
2336
- out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
2337
-
2338
- self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
2339
- self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
2340
- self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
2341
- self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
2342
-
2343
- def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
2344
- self_cls_name = self.__class__.__name__
2345
- deprecate(
2346
- self_cls_name,
2347
- "0.26.0",
2348
- (
2349
- f"Make sure use {self_cls_name[4:]} instead by setting"
2350
- "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
2351
- " `LoraLoaderMixin.load_lora_weights`"
2352
- ),
2353
- )
2354
- attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
2355
- attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
2356
- attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
2357
- attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
2358
-
2359
- attn._modules.pop("processor")
2360
- attn.processor = AttnProcessor2_0()
2361
- return attn.processor(attn, hidden_states, **kwargs)
2362
-
2363
-
2364
- class LoRAXFormersAttnProcessor(nn.Module):
2365
- r"""
2366
- Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
2367
-
2368
- Args:
2369
- hidden_size (`int`, *optional*):
2370
- The hidden size of the attention layer.
2371
- cross_attention_dim (`int`, *optional*):
2372
- The number of channels in the `encoder_hidden_states`.
2373
- rank (`int`, defaults to 4):
2374
- The dimension of the LoRA update matrices.
2375
- attention_op (`Callable`, *optional*, defaults to `None`):
2376
- The base
2377
- [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
2378
- use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
2379
- operator.
2380
- network_alpha (`int`, *optional*):
2381
- Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
2382
- kwargs (`dict`):
2383
- Additional keyword arguments to pass to the `LoRALinearLayer` layers.
2384
- """
2385
-
2386
- def __init__(
2387
- self,
2388
- hidden_size: int,
2389
- cross_attention_dim: int,
2390
- rank: int = 4,
2391
- attention_op: Optional[Callable] = None,
2392
- network_alpha: Optional[int] = None,
2393
- **kwargs,
2394
- ):
2395
- super().__init__()
2396
-
2397
- self.hidden_size = hidden_size
2398
- self.cross_attention_dim = cross_attention_dim
2399
- self.rank = rank
2400
- self.attention_op = attention_op
2401
-
2402
- q_rank = kwargs.pop("q_rank", None)
2403
- q_hidden_size = kwargs.pop("q_hidden_size", None)
2404
- q_rank = q_rank if q_rank is not None else rank
2405
- q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
2406
-
2407
- v_rank = kwargs.pop("v_rank", None)
2408
- v_hidden_size = kwargs.pop("v_hidden_size", None)
2409
- v_rank = v_rank if v_rank is not None else rank
2410
- v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
2411
-
2412
- out_rank = kwargs.pop("out_rank", None)
2413
- out_hidden_size = kwargs.pop("out_hidden_size", None)
2414
- out_rank = out_rank if out_rank is not None else rank
2415
- out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
2416
-
2417
- self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
2418
- self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
2419
- self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
2420
- self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
2421
-
2422
- def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
2423
- self_cls_name = self.__class__.__name__
2424
- deprecate(
2425
- self_cls_name,
2426
- "0.26.0",
2427
- (
2428
- f"Make sure use {self_cls_name[4:]} instead by setting"
2429
- "LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
2430
- " `LoraLoaderMixin.load_lora_weights`"
2431
- ),
2432
- )
2433
- attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
2434
- attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
2435
- attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
2436
- attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
2437
-
2438
- attn._modules.pop("processor")
2439
- attn.processor = XFormersAttnProcessor()
2440
- return attn.processor(attn, hidden_states, **kwargs)
2441
-
2442
-
2443
- class LoRAAttnAddedKVProcessor(nn.Module):
2444
- r"""
2445
- Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
2446
- encoder.
2447
-
2448
- Args:
2449
- hidden_size (`int`, *optional*):
2450
- The hidden size of the attention layer.
2451
- cross_attention_dim (`int`, *optional*, defaults to `None`):
2452
- The number of channels in the `encoder_hidden_states`.
2453
- rank (`int`, defaults to 4):
2454
- The dimension of the LoRA update matrices.
2455
- network_alpha (`int`, *optional*):
2456
- Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
2457
- kwargs (`dict`):
2458
- Additional keyword arguments to pass to the `LoRALinearLayer` layers.
2459
- """
2460
-
2461
- def __init__(
2462
- self,
2463
- hidden_size: int,
2464
- cross_attention_dim: Optional[int] = None,
2465
- rank: int = 4,
2466
- network_alpha: Optional[int] = None,
2467
- ):
2468
- super().__init__()
2469
-
2470
- self.hidden_size = hidden_size
2471
- self.cross_attention_dim = cross_attention_dim
2472
- self.rank = rank
2473
-
2474
- self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
2475
- self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
2476
- self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
2477
- self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
2478
- self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
2479
- self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
2480
-
2481
- def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
2482
- self_cls_name = self.__class__.__name__
2483
- deprecate(
2484
- self_cls_name,
2485
- "0.26.0",
2486
- (
2487
- f"Make sure use {self_cls_name[4:]} instead by setting"
2488
- "LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
2489
- " `LoraLoaderMixin.load_lora_weights`"
2490
- ),
2491
- )
2492
- attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
2493
- attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
2494
- attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
2495
- attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
2496
-
2497
- attn._modules.pop("processor")
2498
- attn.processor = AttnAddedKVProcessor()
2499
- return attn.processor(attn, hidden_states, **kwargs)
2500
-
2501
-
2502
3458
  class IPAdapterAttnProcessor(nn.Module):
2503
3459
  r"""
2504
3460
  Attention processor for Multiple IP-Adapters.
@@ -2927,19 +3883,233 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
2927
3883
  return hidden_states
2928
3884
 
2929
3885
 
2930
- LORA_ATTENTION_PROCESSORS = (
2931
- LoRAAttnProcessor,
2932
- LoRAAttnProcessor2_0,
2933
- LoRAXFormersAttnProcessor,
2934
- LoRAAttnAddedKVProcessor,
2935
- )
3886
+ class PAGIdentitySelfAttnProcessor2_0:
3887
+ r"""
3888
+ Processor for implementing perturbed-attention guidance (PAG) using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
3889
+ PAG reference: https://arxiv.org/abs/2403.17377
3890
+ """
3891
+
3892
+ def __init__(self):
3893
+ if not hasattr(F, "scaled_dot_product_attention"):
3894
+ raise ImportError(
3895
+ "PAGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
3896
+ )
3897
+
3898
+ def __call__(
3899
+ self,
3900
+ attn: Attention,
3901
+ hidden_states: torch.FloatTensor,
3902
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
3903
+ attention_mask: Optional[torch.FloatTensor] = None,
3904
+ temb: Optional[torch.FloatTensor] = None,
3905
+ ) -> torch.Tensor:
3906
+ residual = hidden_states
3907
+ if attn.spatial_norm is not None:
3908
+ hidden_states = attn.spatial_norm(hidden_states, temb)
3909
+
3910
+ input_ndim = hidden_states.ndim
3911
+ if input_ndim == 4:
3912
+ batch_size, channel, height, width = hidden_states.shape
3913
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
3914
+
3915
+ # chunk
3916
+ hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
3917
+
3918
+ # original path
3919
+ batch_size, sequence_length, _ = hidden_states_org.shape
3920
+
3921
+ if attention_mask is not None:
3922
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
3923
+ # scaled_dot_product_attention expects attention_mask shape to be
3924
+ # (batch, heads, source_length, target_length)
3925
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
3926
+
3927
+ if attn.group_norm is not None:
3928
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
3929
+
3930
+ query = attn.to_q(hidden_states_org)
3931
+ key = attn.to_k(hidden_states_org)
3932
+ value = attn.to_v(hidden_states_org)
3933
+
3934
+ inner_dim = key.shape[-1]
3935
+ head_dim = inner_dim // attn.heads
3936
+
3937
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3938
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3939
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3940
+
3941
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
3942
+ # TODO: add support for attn.scale when we move to Torch 2.1
3943
+ hidden_states_org = F.scaled_dot_product_attention(
3944
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
3945
+ )
3946
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
3947
+ hidden_states_org = hidden_states_org.to(query.dtype)
3948
+
3949
+ # linear proj
3950
+ hidden_states_org = attn.to_out[0](hidden_states_org)
3951
+ # dropout
3952
+ hidden_states_org = attn.to_out[1](hidden_states_org)
3953
+
3954
+ if input_ndim == 4:
3955
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
3956
+
3957
+ # perturbed path (identity attention)
3958
+ batch_size, sequence_length, _ = hidden_states_ptb.shape
3959
+
3960
+ if attn.group_norm is not None:
3961
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
3962
+
3963
+ hidden_states_ptb = attn.to_v(hidden_states_ptb)
3964
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
3965
+
3966
+ # linear proj
3967
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
3968
+ # dropout
3969
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
3970
+
3971
+ if input_ndim == 4:
3972
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
3973
+
3974
+ # cat
3975
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
3976
+
3977
+ if attn.residual_connection:
3978
+ hidden_states = hidden_states + residual
3979
+
3980
+ hidden_states = hidden_states / attn.rescale_output_factor
3981
+
3982
+ return hidden_states
3983
+
3984
+
3985
+ class PAGCFGIdentitySelfAttnProcessor2_0:
3986
+ r"""
3987
+ Processor for implementing perturbed-attention guidance (PAG) together with classifier-free guidance (CFG) using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
3988
+ PAG reference: https://arxiv.org/abs/2403.17377
3989
+ """
3990
+
3991
+ def __init__(self):
3992
+ if not hasattr(F, "scaled_dot_product_attention"):
3993
+ raise ImportError(
3994
+ "PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
3995
+ )
3996
+
3997
+ def __call__(
3998
+ self,
3999
+ attn: Attention,
4000
+ hidden_states: torch.FloatTensor,
4001
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
4002
+ attention_mask: Optional[torch.FloatTensor] = None,
4003
+ temb: Optional[torch.FloatTensor] = None,
4004
+ ) -> torch.Tensor:
4005
+ residual = hidden_states
4006
+ if attn.spatial_norm is not None:
4007
+ hidden_states = attn.spatial_norm(hidden_states, temb)
4008
+
4009
+ input_ndim = hidden_states.ndim
4010
+ if input_ndim == 4:
4011
+ batch_size, channel, height, width = hidden_states.shape
4012
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
4013
+
4014
+ # chunk
4015
+ hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
4016
+ hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
4017
+
4018
+ # original path
4019
+ batch_size, sequence_length, _ = hidden_states_org.shape
4020
+
4021
+ if attention_mask is not None:
4022
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
4023
+ # scaled_dot_product_attention expects attention_mask shape to be
4024
+ # (batch, heads, source_length, target_length)
4025
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
4026
+
4027
+ if attn.group_norm is not None:
4028
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
4029
+
4030
+ query = attn.to_q(hidden_states_org)
4031
+ key = attn.to_k(hidden_states_org)
4032
+ value = attn.to_v(hidden_states_org)
4033
+
4034
+ inner_dim = key.shape[-1]
4035
+ head_dim = inner_dim // attn.heads
4036
+
4037
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
4038
+
4039
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
4040
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
4041
+
4042
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
4043
+ # TODO: add support for attn.scale when we move to Torch 2.1
4044
+ hidden_states_org = F.scaled_dot_product_attention(
4045
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
4046
+ )
4047
+
4048
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
4049
+ hidden_states_org = hidden_states_org.to(query.dtype)
4050
+
4051
+ # linear proj
4052
+ hidden_states_org = attn.to_out[0](hidden_states_org)
4053
+ # dropout
4054
+ hidden_states_org = attn.to_out[1](hidden_states_org)
4055
+
4056
+ if input_ndim == 4:
4057
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
4058
+
4059
+ # perturbed path (identity attention)
4060
+ batch_size, sequence_length, _ = hidden_states_ptb.shape
4061
+
4062
+ if attn.group_norm is not None:
4063
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
4064
+
4065
+ hidden_states_ptb = attn.to_v(hidden_states_ptb)
4067
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
4068
+
4069
+ # linear proj
4070
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
4071
+ # dropout
4072
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
4073
+
4074
+ if input_ndim == 4:
4075
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
4076
+
4077
+ # cat
4078
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
4079
+
4080
+ if attn.residual_connection:
4081
+ hidden_states = hidden_states + residual
4082
+
4083
+ hidden_states = hidden_states / attn.rescale_output_factor
4084
+
4085
+ return hidden_states
4086
+
4087
+
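Both PAG processors above assume the caller has already laid out the batch so that the perturbed branch can be split off with `chunk`: the identity variant expects [original, perturbed] (chunk(2)), and the CFG variant expects [unconditional, conditional, perturbed] (chunk(3)) with the first two re-concatenated for the regular attention path. A minimal sketch of that bookkeeping, with illustrative shapes:

```python
import torch

batch, seq, dim = 2, 16, 64
cond = torch.randn(batch, seq, dim)
uncond = torch.randn(batch, seq, dim)
perturbed = cond.clone()  # the perturbed branch is a copy of the conditional latents

# Layout consumed by PAGCFGIdentitySelfAttnProcessor2_0.__call__:
hidden_states = torch.cat([uncond, cond, perturbed])   # (3 * batch, seq, dim)
hs_uncond, hs_org, hs_ptb = hidden_states.chunk(3)
hs_org = torch.cat([hs_uncond, hs_org])                # goes through the regular SDPA path
# hs_ptb goes through the "identity attention" path: to_v -> to_out only.

# Layout consumed by PAGIdentitySelfAttnProcessor2_0.__call__ (no CFG):
hidden_states = torch.cat([cond, perturbed])           # (2 * batch, seq, dim)
hs_org, hs_ptb = hidden_states.chunk(2)
print(hs_org.shape, hs_ptb.shape)
```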
4088
+ class LoRAAttnProcessor:
4089
+ def __init__(self):
4090
+ pass
4091
+
4092
+
4093
+ class LoRAAttnProcessor2_0:
4094
+ def __init__(self):
4095
+ pass
4096
+
4097
+
4098
+ class LoRAXFormersAttnProcessor:
4099
+ def __init__(self):
4100
+ pass
4101
+
4102
+
4103
+ class LoRAAttnAddedKVProcessor:
4104
+ def __init__(self):
4105
+ pass
4106
+
2936
4107
 
2937
4108
  ADDED_KV_ATTENTION_PROCESSORS = (
2938
4109
  AttnAddedKVProcessor,
2939
4110
  SlicedAttnAddedKVProcessor,
2940
4111
  AttnAddedKVProcessor2_0,
2941
4112
  XFormersAttnAddedKVProcessor,
2942
- LoRAAttnAddedKVProcessor,
2943
4113
  )
2944
4114
 
2945
4115
  CROSS_ATTENTION_PROCESSORS = (
@@ -2947,9 +4117,6 @@ CROSS_ATTENTION_PROCESSORS = (
2947
4117
  AttnProcessor2_0,
2948
4118
  XFormersAttnProcessor,
2949
4119
  SlicedAttnProcessor,
2950
- LoRAAttnProcessor,
2951
- LoRAAttnProcessor2_0,
2952
- LoRAXFormersAttnProcessor,
2953
4120
  IPAdapterAttnProcessor,
2954
4121
  IPAdapterAttnProcessor2_0,
2955
4122
  )
@@ -2967,9 +4134,8 @@ AttentionProcessor = Union[
2967
4134
  CustomDiffusionAttnProcessor,
2968
4135
  CustomDiffusionXFormersAttnProcessor,
2969
4136
  CustomDiffusionAttnProcessor2_0,
2970
- # deprecated
2971
- LoRAAttnProcessor,
2972
- LoRAAttnProcessor2_0,
2973
- LoRAXFormersAttnProcessor,
2974
- LoRAAttnAddedKVProcessor,
4137
+ PAGCFGIdentitySelfAttnProcessor2_0,
4138
+ PAGIdentitySelfAttnProcessor2_0,
4139
+ PAGCFGHunyuanAttnProcessor2_0,
4140
+ PAGHunyuanAttnProcessor2_0,
2975
4141
  ]
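The updated AttentionProcessor union registers the new PAG processors alongside the existing ones. A minimal sketch of installing one of them on a UNet's self-attention layers via `set_attn_processor`; the checkpoint id is a placeholder, and the PAG pipelines added in this release handle this wiring, including batching the perturbed branch, internally, so this is illustration rather than the recommended entry point:

```python
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    PAGIdentitySelfAttnProcessor2_0,
)

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder checkpoint
    subfolder="unet",
    torch_dtype=torch.float16,
)

procs = {}
for name in unet.attn_processors.keys():
    # "attn1" modules are self-attention; only those receive the PAG identity processor here.
    procs[name] = PAGIdentitySelfAttnProcessor2_0() if "attn1" in name else AttnProcessor2_0()
unet.set_attn_processor(procs)
```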