diffusers 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (214) hide show
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +74 -28
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/METADATA +69 -69
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
@@ -20,8 +20,8 @@ import torch.nn.functional as F
20
20
  from torch import nn
21
21
 
22
22
  from ..image_processor import IPAdapterMaskProcessor
23
- from ..utils import deprecate, logging
24
- from ..utils.import_utils import is_torch_npu_available, is_xformers_available
23
+ from ..utils import deprecate, is_torch_xla_available, logging
24
+ from ..utils.import_utils import is_torch_npu_available, is_torch_xla_version, is_xformers_available
25
25
  from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
26
26
 
27
27
 
@@ -36,6 +36,15 @@ if is_xformers_available():
36
36
  else:
37
37
  xformers = None
38
38
 
39
+ if is_torch_xla_available():
40
+ # flash attention pallas kernel is introduced in the torch_xla 2.3 release.
41
+ if is_torch_xla_version(">", "2.2"):
42
+ from torch_xla.experimental.custom_kernel import flash_attention
43
+ from torch_xla.runtime import is_spmd
44
+ XLA_AVAILABLE = True
45
+ else:
46
+ XLA_AVAILABLE = False
47
+
39
48
 
40
49
  @maybe_allow_in_graph
41
50
  class Attention(nn.Module):
@@ -120,14 +129,16 @@ class Attention(nn.Module):
120
129
  _from_deprecated_attn_block: bool = False,
121
130
  processor: Optional["AttnProcessor"] = None,
122
131
  out_dim: int = None,
132
+ out_context_dim: int = None,
123
133
  context_pre_only=None,
124
134
  pre_only=False,
125
135
  elementwise_affine: bool = True,
136
+ is_causal: bool = False,
126
137
  ):
127
138
  super().__init__()
128
139
 
129
140
  # To prevent circular import.
130
- from .normalization import FP32LayerNorm, RMSNorm
141
+ from .normalization import FP32LayerNorm, LpNorm, RMSNorm
131
142
 
132
143
  self.inner_dim = out_dim if out_dim is not None else dim_head * heads
133
144
  self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
@@ -142,8 +153,10 @@ class Attention(nn.Module):
142
153
  self.dropout = dropout
143
154
  self.fused_projections = False
144
155
  self.out_dim = out_dim if out_dim is not None else query_dim
156
+ self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
145
157
  self.context_pre_only = context_pre_only
146
158
  self.pre_only = pre_only
159
+ self.is_causal = is_causal
147
160
 
148
161
  # we make use of this private variable to know whether this class is loaded
149
162
  # with an deprecated state dict so that we can convert it on the fly
@@ -186,12 +199,19 @@ class Attention(nn.Module):
186
199
  self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
187
200
  self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
188
201
  elif qk_norm == "layer_norm_across_heads":
189
- # Lumina applys qk norm across all heads
202
+ # Lumina applies qk norm across all heads
190
203
  self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
191
204
  self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
192
205
  elif qk_norm == "rms_norm":
193
206
  self.norm_q = RMSNorm(dim_head, eps=eps)
194
207
  self.norm_k = RMSNorm(dim_head, eps=eps)
208
+ elif qk_norm == "rms_norm_across_heads":
209
+ # LTX applies qk norm across all heads
210
+ self.norm_q = RMSNorm(dim_head * heads, eps=eps)
211
+ self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps)
212
+ elif qk_norm == "l2":
213
+ self.norm_q = LpNorm(p=2, dim=-1, eps=eps)
214
+ self.norm_k = LpNorm(p=2, dim=-1, eps=eps)
195
215
  else:
196
216
  raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'")
197
217
 
@@ -234,14 +254,22 @@ class Attention(nn.Module):
234
254
  self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
235
255
  if self.context_pre_only is not None:
236
256
  self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
257
+ else:
258
+ self.add_q_proj = None
259
+ self.add_k_proj = None
260
+ self.add_v_proj = None
237
261
 
238
262
  if not self.pre_only:
239
263
  self.to_out = nn.ModuleList([])
240
264
  self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
241
265
  self.to_out.append(nn.Dropout(dropout))
266
+ else:
267
+ self.to_out = None
242
268
 
243
269
  if self.context_pre_only is not None and not self.context_pre_only:
244
- self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
270
+ self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
271
+ else:
272
+ self.to_add_out = None
245
273
 
246
274
  if qk_norm is not None and added_kv_proj_dim is not None:
247
275
  if qk_norm == "fp32_layer_norm":
@@ -268,6 +296,33 @@ class Attention(nn.Module):
268
296
  )
269
297
  self.set_processor(processor)
270
298
 
299
+ def set_use_xla_flash_attention(
300
+ self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None
301
+ ) -> None:
302
+ r"""
303
+ Set whether to use xla flash attention from `torch_xla` or not.
304
+
305
+ Args:
306
+ use_xla_flash_attention (`bool`):
307
+ Whether to use pallas flash attention kernel from `torch_xla` or not.
308
+ partition_spec (`Tuple[]`, *optional*):
309
+ Specify the partition specification if using SPMD. Otherwise None.
310
+ """
311
+ if use_xla_flash_attention:
312
+ if not is_torch_xla_available:
313
+ raise "torch_xla is not available"
314
+ elif is_torch_xla_version("<", "2.3"):
315
+ raise "flash attention pallas kernel is supported from torch_xla version 2.3"
316
+ elif is_spmd() and is_torch_xla_version("<", "2.4"):
317
+ raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4"
318
+ else:
319
+ processor = XLAFlashAttnProcessor2_0(partition_spec)
320
+ else:
321
+ processor = (
322
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
323
+ )
324
+ self.set_processor(processor)
325
+
271
326
  def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
272
327
  r"""
273
328
  Set whether to use npu flash attention from `torch_npu` or not.
@@ -311,6 +366,17 @@ class Attention(nn.Module):
311
366
  XFormersAttnAddedKVProcessor,
312
367
  ),
313
368
  )
369
+ is_ip_adapter = hasattr(self, "processor") and isinstance(
370
+ self.processor,
371
+ (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
372
+ )
373
+ is_joint_processor = hasattr(self, "processor") and isinstance(
374
+ self.processor,
375
+ (
376
+ JointAttnProcessor2_0,
377
+ XFormersJointAttnProcessor,
378
+ ),
379
+ )
314
380
 
315
381
  if use_memory_efficient_attention_xformers:
316
382
  if is_added_kv_processor and is_custom_diffusion:
@@ -361,6 +427,21 @@ class Attention(nn.Module):
361
427
  "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
362
428
  )
363
429
  processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
430
+ elif is_ip_adapter:
431
+ processor = IPAdapterXFormersAttnProcessor(
432
+ hidden_size=self.processor.hidden_size,
433
+ cross_attention_dim=self.processor.cross_attention_dim,
434
+ num_tokens=self.processor.num_tokens,
435
+ scale=self.processor.scale,
436
+ attention_op=attention_op,
437
+ )
438
+ processor.load_state_dict(self.processor.state_dict())
439
+ if hasattr(self.processor, "to_k_ip"):
440
+ processor.to(
441
+ device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
442
+ )
443
+ elif is_joint_processor:
444
+ processor = XFormersJointAttnProcessor(attention_op=attention_op)
364
445
  else:
365
446
  processor = XFormersAttnProcessor(attention_op=attention_op)
366
447
  else:
@@ -379,6 +460,18 @@ class Attention(nn.Module):
379
460
  processor.load_state_dict(self.processor.state_dict())
380
461
  if hasattr(self.processor, "to_k_custom_diffusion"):
381
462
  processor.to(self.processor.to_k_custom_diffusion.weight.device)
463
+ elif is_ip_adapter:
464
+ processor = IPAdapterAttnProcessor2_0(
465
+ hidden_size=self.processor.hidden_size,
466
+ cross_attention_dim=self.processor.cross_attention_dim,
467
+ num_tokens=self.processor.num_tokens,
468
+ scale=self.processor.scale,
469
+ )
470
+ processor.load_state_dict(self.processor.state_dict())
471
+ if hasattr(self.processor, "to_k_ip"):
472
+ processor.to(
473
+ device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
474
+ )
382
475
  else:
383
476
  # set attention processor
384
477
  # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
@@ -482,7 +575,7 @@ class Attention(nn.Module):
482
575
  # For standard processors that are defined here, `**cross_attention_kwargs` is empty
483
576
 
484
577
  attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
485
- quiet_attn_parameters = {"ip_adapter_masks"}
578
+ quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
486
579
  unused_kwargs = [
487
580
  k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
488
581
  ]
@@ -697,7 +790,11 @@ class Attention(nn.Module):
697
790
  self.to_kv.bias.copy_(concatenated_bias)
698
791
 
699
792
  # handle added projections for SD3 and others.
700
- if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
793
+ if (
794
+ getattr(self, "add_q_proj", None) is not None
795
+ and getattr(self, "add_k_proj", None) is not None
796
+ and getattr(self, "add_v_proj", None) is not None
797
+ ):
701
798
  concatenated_weights = torch.cat(
702
799
  [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
703
800
  )
@@ -717,6 +814,269 @@ class Attention(nn.Module):
717
814
  self.fused_projections = fuse
718
815
 
719
816
 
817
+ class SanaMultiscaleAttentionProjection(nn.Module):
818
+ def __init__(
819
+ self,
820
+ in_channels: int,
821
+ num_attention_heads: int,
822
+ kernel_size: int,
823
+ ) -> None:
824
+ super().__init__()
825
+
826
+ channels = 3 * in_channels
827
+ self.proj_in = nn.Conv2d(
828
+ channels,
829
+ channels,
830
+ kernel_size,
831
+ padding=kernel_size // 2,
832
+ groups=channels,
833
+ bias=False,
834
+ )
835
+ self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
836
+
837
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
838
+ hidden_states = self.proj_in(hidden_states)
839
+ hidden_states = self.proj_out(hidden_states)
840
+ return hidden_states
841
+
842
+
843
+ class SanaMultiscaleLinearAttention(nn.Module):
844
+ r"""Lightweight multi-scale linear attention"""
845
+
846
+ def __init__(
847
+ self,
848
+ in_channels: int,
849
+ out_channels: int,
850
+ num_attention_heads: Optional[int] = None,
851
+ attention_head_dim: int = 8,
852
+ mult: float = 1.0,
853
+ norm_type: str = "batch_norm",
854
+ kernel_sizes: Tuple[int, ...] = (5,),
855
+ eps: float = 1e-15,
856
+ residual_connection: bool = False,
857
+ ):
858
+ super().__init__()
859
+
860
+ # To prevent circular import
861
+ from .normalization import get_normalization
862
+
863
+ self.eps = eps
864
+ self.attention_head_dim = attention_head_dim
865
+ self.norm_type = norm_type
866
+ self.residual_connection = residual_connection
867
+
868
+ num_attention_heads = (
869
+ int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads
870
+ )
871
+ inner_dim = num_attention_heads * attention_head_dim
872
+
873
+ self.to_q = nn.Linear(in_channels, inner_dim, bias=False)
874
+ self.to_k = nn.Linear(in_channels, inner_dim, bias=False)
875
+ self.to_v = nn.Linear(in_channels, inner_dim, bias=False)
876
+
877
+ self.to_qkv_multiscale = nn.ModuleList()
878
+ for kernel_size in kernel_sizes:
879
+ self.to_qkv_multiscale.append(
880
+ SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
881
+ )
882
+
883
+ self.nonlinearity = nn.ReLU()
884
+ self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
885
+ self.norm_out = get_normalization(norm_type, num_features=out_channels)
886
+
887
+ self.processor = SanaMultiscaleAttnProcessor2_0()
888
+
889
+ def apply_linear_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
890
+ value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1) # Adds padding
891
+ scores = torch.matmul(value, key.transpose(-1, -2))
892
+ hidden_states = torch.matmul(scores, query)
893
+
894
+ hidden_states = hidden_states.to(dtype=torch.float32)
895
+ hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
896
+ return hidden_states
897
+
898
+ def apply_quadratic_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
899
+ scores = torch.matmul(key.transpose(-1, -2), query)
900
+ scores = scores.to(dtype=torch.float32)
901
+ scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
902
+ hidden_states = torch.matmul(value, scores)
903
+ return hidden_states
904
+
905
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
906
+ return self.processor(self, hidden_states)
907
+
908
+
909
+ class MochiAttention(nn.Module):
910
+ def __init__(
911
+ self,
912
+ query_dim: int,
913
+ added_kv_proj_dim: int,
914
+ processor: "MochiAttnProcessor2_0",
915
+ heads: int = 8,
916
+ dim_head: int = 64,
917
+ dropout: float = 0.0,
918
+ bias: bool = False,
919
+ added_proj_bias: bool = True,
920
+ out_dim: Optional[int] = None,
921
+ out_context_dim: Optional[int] = None,
922
+ out_bias: bool = True,
923
+ context_pre_only: bool = False,
924
+ eps: float = 1e-5,
925
+ ):
926
+ super().__init__()
927
+ from .normalization import MochiRMSNorm
928
+
929
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
930
+ self.out_dim = out_dim if out_dim is not None else query_dim
931
+ self.out_context_dim = out_context_dim if out_context_dim else query_dim
932
+ self.context_pre_only = context_pre_only
933
+
934
+ self.heads = out_dim // dim_head if out_dim is not None else heads
935
+
936
+ self.norm_q = MochiRMSNorm(dim_head, eps, True)
937
+ self.norm_k = MochiRMSNorm(dim_head, eps, True)
938
+ self.norm_added_q = MochiRMSNorm(dim_head, eps, True)
939
+ self.norm_added_k = MochiRMSNorm(dim_head, eps, True)
940
+
941
+ self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
942
+ self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
943
+ self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
944
+
945
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
946
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
947
+ if self.context_pre_only is not None:
948
+ self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
949
+
950
+ self.to_out = nn.ModuleList([])
951
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
952
+ self.to_out.append(nn.Dropout(dropout))
953
+
954
+ if not self.context_pre_only:
955
+ self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
956
+
957
+ self.processor = processor
958
+
959
+ def forward(
960
+ self,
961
+ hidden_states: torch.Tensor,
962
+ encoder_hidden_states: Optional[torch.Tensor] = None,
963
+ attention_mask: Optional[torch.Tensor] = None,
964
+ **kwargs,
965
+ ):
966
+ return self.processor(
967
+ self,
968
+ hidden_states,
969
+ encoder_hidden_states=encoder_hidden_states,
970
+ attention_mask=attention_mask,
971
+ **kwargs,
972
+ )
973
+
974
+
975
+ class MochiAttnProcessor2_0:
976
+ """Attention processor used in Mochi."""
977
+
978
+ def __init__(self):
979
+ if not hasattr(F, "scaled_dot_product_attention"):
980
+ raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
981
+
982
+ def __call__(
983
+ self,
984
+ attn: "MochiAttention",
985
+ hidden_states: torch.Tensor,
986
+ encoder_hidden_states: torch.Tensor,
987
+ attention_mask: torch.Tensor,
988
+ image_rotary_emb: Optional[torch.Tensor] = None,
989
+ ) -> torch.Tensor:
990
+ query = attn.to_q(hidden_states)
991
+ key = attn.to_k(hidden_states)
992
+ value = attn.to_v(hidden_states)
993
+
994
+ query = query.unflatten(2, (attn.heads, -1))
995
+ key = key.unflatten(2, (attn.heads, -1))
996
+ value = value.unflatten(2, (attn.heads, -1))
997
+
998
+ if attn.norm_q is not None:
999
+ query = attn.norm_q(query)
1000
+ if attn.norm_k is not None:
1001
+ key = attn.norm_k(key)
1002
+
1003
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
1004
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
1005
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
1006
+
1007
+ encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
1008
+ encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
1009
+ encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
1010
+
1011
+ if attn.norm_added_q is not None:
1012
+ encoder_query = attn.norm_added_q(encoder_query)
1013
+ if attn.norm_added_k is not None:
1014
+ encoder_key = attn.norm_added_k(encoder_key)
1015
+
1016
+ if image_rotary_emb is not None:
1017
+
1018
+ def apply_rotary_emb(x, freqs_cos, freqs_sin):
1019
+ x_even = x[..., 0::2].float()
1020
+ x_odd = x[..., 1::2].float()
1021
+
1022
+ cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
1023
+ sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
1024
+
1025
+ return torch.stack([cos, sin], dim=-1).flatten(-2)
1026
+
1027
+ query = apply_rotary_emb(query, *image_rotary_emb)
1028
+ key = apply_rotary_emb(key, *image_rotary_emb)
1029
+
1030
+ query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
1031
+ encoder_query, encoder_key, encoder_value = (
1032
+ encoder_query.transpose(1, 2),
1033
+ encoder_key.transpose(1, 2),
1034
+ encoder_value.transpose(1, 2),
1035
+ )
1036
+
1037
+ sequence_length = query.size(2)
1038
+ encoder_sequence_length = encoder_query.size(2)
1039
+ total_length = sequence_length + encoder_sequence_length
1040
+
1041
+ batch_size, heads, _, dim = query.shape
1042
+ attn_outputs = []
1043
+ for idx in range(batch_size):
1044
+ mask = attention_mask[idx][None, :]
1045
+ valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
1046
+
1047
+ valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :]
1048
+ valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :]
1049
+ valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :]
1050
+
1051
+ valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2)
1052
+ valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2)
1053
+ valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2)
1054
+
1055
+ attn_output = F.scaled_dot_product_attention(
1056
+ valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False
1057
+ )
1058
+ valid_sequence_length = attn_output.size(2)
1059
+ attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
1060
+ attn_outputs.append(attn_output)
1061
+
1062
+ hidden_states = torch.cat(attn_outputs, dim=0)
1063
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
1064
+
1065
+ hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
1066
+ (sequence_length, encoder_sequence_length), dim=1
1067
+ )
1068
+
1069
+ # linear proj
1070
+ hidden_states = attn.to_out[0](hidden_states)
1071
+ # dropout
1072
+ hidden_states = attn.to_out[1](hidden_states)
1073
+
1074
+ if hasattr(attn, "to_add_out"):
1075
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
1076
+
1077
+ return hidden_states, encoder_hidden_states
1078
+
1079
+
720
1080
  class AttnProcessor:
721
1081
  r"""
722
1082
  Default processor for performing attention-related computations.
@@ -1136,6 +1496,7 @@ class PAGJointAttnProcessor2_0:
1136
1496
  attn: Attention,
1137
1497
  hidden_states: torch.FloatTensor,
1138
1498
  encoder_hidden_states: torch.FloatTensor = None,
1499
+ attention_mask: Optional[torch.FloatTensor] = None,
1139
1500
  ) -> torch.FloatTensor:
1140
1501
  residual = hidden_states
1141
1502
 
@@ -1521,92 +1882,84 @@ class FusedJointAttnProcessor2_0:
1521
1882
  return hidden_states, encoder_hidden_states
1522
1883
 
1523
1884
 
1524
- class AuraFlowAttnProcessor2_0:
1525
- """Attention processor used typically in processing Aura Flow."""
1885
+ class XFormersJointAttnProcessor:
1886
+ r"""
1887
+ Processor for implementing memory efficient attention using xFormers.
1526
1888
 
1527
- def __init__(self):
1528
- if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
1529
- raise ImportError(
1530
- "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
1531
- )
1889
+ Args:
1890
+ attention_op (`Callable`, *optional*, defaults to `None`):
1891
+ The base
1892
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1893
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
1894
+ operator.
1895
+ """
1896
+
1897
+ def __init__(self, attention_op: Optional[Callable] = None):
1898
+ self.attention_op = attention_op
1532
1899
 
1533
1900
  def __call__(
1534
1901
  self,
1535
1902
  attn: Attention,
1536
1903
  hidden_states: torch.FloatTensor,
1537
1904
  encoder_hidden_states: torch.FloatTensor = None,
1905
+ attention_mask: Optional[torch.FloatTensor] = None,
1538
1906
  *args,
1539
1907
  **kwargs,
1540
1908
  ) -> torch.FloatTensor:
1541
- batch_size = hidden_states.shape[0]
1909
+ residual = hidden_states
1542
1910
 
1543
1911
  # `sample` projections.
1544
1912
  query = attn.to_q(hidden_states)
1545
1913
  key = attn.to_k(hidden_states)
1546
1914
  value = attn.to_v(hidden_states)
1547
1915
 
1548
- # `context` projections.
1549
- if encoder_hidden_states is not None:
1550
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
1551
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1552
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1553
-
1554
- # Reshape.
1555
- inner_dim = key.shape[-1]
1556
- head_dim = inner_dim // attn.heads
1557
- query = query.view(batch_size, -1, attn.heads, head_dim)
1558
- key = key.view(batch_size, -1, attn.heads, head_dim)
1559
- value = value.view(batch_size, -1, attn.heads, head_dim)
1916
+ query = attn.head_to_batch_dim(query).contiguous()
1917
+ key = attn.head_to_batch_dim(key).contiguous()
1918
+ value = attn.head_to_batch_dim(value).contiguous()
1560
1919
 
1561
- # Apply QK norm.
1562
1920
  if attn.norm_q is not None:
1563
1921
  query = attn.norm_q(query)
1564
1922
  if attn.norm_k is not None:
1565
1923
  key = attn.norm_k(key)
1566
1924
 
1567
- # Concatenate the projections.
1925
+ # `context` projections.
1568
1926
  if encoder_hidden_states is not None:
1569
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
1570
- batch_size, -1, attn.heads, head_dim
1571
- )
1572
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
1573
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
1574
- batch_size, -1, attn.heads, head_dim
1575
- )
1927
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
1928
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1929
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1930
+
1931
+ encoder_hidden_states_query_proj = attn.head_to_batch_dim(encoder_hidden_states_query_proj).contiguous()
1932
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj).contiguous()
1933
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj).contiguous()
1576
1934
 
1577
1935
  if attn.norm_added_q is not None:
1578
1936
  encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
1579
1937
  if attn.norm_added_k is not None:
1580
- encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj)
1581
-
1582
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
1583
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
1584
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
1938
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
1585
1939
 
1586
- query = query.transpose(1, 2)
1587
- key = key.transpose(1, 2)
1588
- value = value.transpose(1, 2)
1940
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
1941
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
1942
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
1589
1943
 
1590
- # Attention.
1591
- hidden_states = F.scaled_dot_product_attention(
1592
- query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
1944
+ hidden_states = xformers.ops.memory_efficient_attention(
1945
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1593
1946
  )
1594
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1595
1947
  hidden_states = hidden_states.to(query.dtype)
1948
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1596
1949
 
1597
- # Split the attention outputs.
1598
1950
  if encoder_hidden_states is not None:
1951
+ # Split the attention outputs.
1599
1952
  hidden_states, encoder_hidden_states = (
1600
- hidden_states[:, encoder_hidden_states.shape[1] :],
1601
- hidden_states[:, : encoder_hidden_states.shape[1]],
1953
+ hidden_states[:, : residual.shape[1]],
1954
+ hidden_states[:, residual.shape[1] :],
1602
1955
  )
1956
+ if not attn.context_pre_only:
1957
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
1603
1958
 
1604
1959
  # linear proj
1605
1960
  hidden_states = attn.to_out[0](hidden_states)
1606
1961
  # dropout
1607
1962
  hidden_states = attn.to_out[1](hidden_states)
1608
- if encoder_hidden_states is not None:
1609
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
1610
1963
 
1611
1964
  if encoder_hidden_states is not None:
1612
1965
  return hidden_states, encoder_hidden_states
@@ -1614,27 +1967,214 @@ class AuraFlowAttnProcessor2_0:
1614
1967
  return hidden_states
1615
1968
 
1616
1969
 
1617
- class FusedAuraFlowAttnProcessor2_0:
1618
- """Attention processor used typically in processing Aura Flow with fused projections."""
1970
+ class AllegroAttnProcessor2_0:
1971
+ r"""
1972
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
1973
+ used in the Allegro model. It applies a normalization layer and rotary embedding on the query and key vector.
1974
+ """
1619
1975
 
1620
1976
  def __init__(self):
1621
- if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
1977
+ if not hasattr(F, "scaled_dot_product_attention"):
1622
1978
  raise ImportError(
1623
- "FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
1979
+ "AllegroAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
1624
1980
  )
1625
1981
 
1626
1982
  def __call__(
1627
1983
  self,
1628
1984
  attn: Attention,
1629
- hidden_states: torch.FloatTensor,
1630
- encoder_hidden_states: torch.FloatTensor = None,
1631
- *args,
1632
- **kwargs,
1633
- ) -> torch.FloatTensor:
1634
- batch_size = hidden_states.shape[0]
1635
-
1636
- # `sample` projections.
1637
- qkv = attn.to_qkv(hidden_states)
1985
+ hidden_states: torch.Tensor,
1986
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1987
+ attention_mask: Optional[torch.Tensor] = None,
1988
+ temb: Optional[torch.Tensor] = None,
1989
+ image_rotary_emb: Optional[torch.Tensor] = None,
1990
+ ) -> torch.Tensor:
1991
+ residual = hidden_states
1992
+
1993
+ if attn.spatial_norm is not None:
1994
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1995
+
1996
+ input_ndim = hidden_states.ndim
1997
+
1998
+ if input_ndim == 4:
1999
+ batch_size, channel, height, width = hidden_states.shape
2000
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
2001
+
2002
+ batch_size, sequence_length, _ = (
2003
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
2004
+ )
2005
+
2006
+ if attention_mask is not None:
2007
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
2008
+ # scaled_dot_product_attention expects attention_mask shape to be
2009
+ # (batch, heads, source_length, target_length)
2010
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
2011
+
2012
+ if attn.group_norm is not None:
2013
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
2014
+
2015
+ query = attn.to_q(hidden_states)
2016
+
2017
+ if encoder_hidden_states is None:
2018
+ encoder_hidden_states = hidden_states
2019
+ elif attn.norm_cross:
2020
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
2021
+
2022
+ key = attn.to_k(encoder_hidden_states)
2023
+ value = attn.to_v(encoder_hidden_states)
2024
+
2025
+ inner_dim = key.shape[-1]
2026
+ head_dim = inner_dim // attn.heads
2027
+
2028
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2029
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2030
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2031
+
2032
+ # Apply RoPE if needed
2033
+ if image_rotary_emb is not None and not attn.is_cross_attention:
2034
+ from .embeddings import apply_rotary_emb_allegro
2035
+
2036
+ query = apply_rotary_emb_allegro(query, image_rotary_emb[0], image_rotary_emb[1])
2037
+ key = apply_rotary_emb_allegro(key, image_rotary_emb[0], image_rotary_emb[1])
2038
+
2039
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
2040
+ # TODO: add support for attn.scale when we move to Torch 2.1
2041
+ hidden_states = F.scaled_dot_product_attention(
2042
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
2043
+ )
2044
+
2045
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2046
+ hidden_states = hidden_states.to(query.dtype)
2047
+
2048
+ # linear proj
2049
+ hidden_states = attn.to_out[0](hidden_states)
2050
+ # dropout
2051
+ hidden_states = attn.to_out[1](hidden_states)
2052
+
2053
+ if input_ndim == 4:
2054
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
2055
+
2056
+ if attn.residual_connection:
2057
+ hidden_states = hidden_states + residual
2058
+
2059
+ hidden_states = hidden_states / attn.rescale_output_factor
2060
+
2061
+ return hidden_states
2062
+
2063
+
2064
+ class AuraFlowAttnProcessor2_0:
2065
+ """Attention processor used typically in processing Aura Flow."""
2066
+
2067
+ def __init__(self):
2068
+ if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
2069
+ raise ImportError(
2070
+ "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
2071
+ )
2072
+
2073
+ def __call__(
2074
+ self,
2075
+ attn: Attention,
2076
+ hidden_states: torch.FloatTensor,
2077
+ encoder_hidden_states: torch.FloatTensor = None,
2078
+ *args,
2079
+ **kwargs,
2080
+ ) -> torch.FloatTensor:
2081
+ batch_size = hidden_states.shape[0]
2082
+
2083
+ # `sample` projections.
2084
+ query = attn.to_q(hidden_states)
2085
+ key = attn.to_k(hidden_states)
2086
+ value = attn.to_v(hidden_states)
2087
+
2088
+ # `context` projections.
2089
+ if encoder_hidden_states is not None:
2090
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
2091
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
2092
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
2093
+
2094
+ # Reshape.
2095
+ inner_dim = key.shape[-1]
2096
+ head_dim = inner_dim // attn.heads
2097
+ query = query.view(batch_size, -1, attn.heads, head_dim)
2098
+ key = key.view(batch_size, -1, attn.heads, head_dim)
2099
+ value = value.view(batch_size, -1, attn.heads, head_dim)
2100
+
2101
+ # Apply QK norm.
2102
+ if attn.norm_q is not None:
2103
+ query = attn.norm_q(query)
2104
+ if attn.norm_k is not None:
2105
+ key = attn.norm_k(key)
2106
+
2107
+ # Concatenate the projections.
2108
+ if encoder_hidden_states is not None:
2109
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
2110
+ batch_size, -1, attn.heads, head_dim
2111
+ )
2112
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
2113
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
2114
+ batch_size, -1, attn.heads, head_dim
2115
+ )
2116
+
2117
+ if attn.norm_added_q is not None:
2118
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
2119
+ if attn.norm_added_k is not None:
2120
+ encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj)
2121
+
2122
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
2123
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
2124
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
2125
+
2126
+ query = query.transpose(1, 2)
2127
+ key = key.transpose(1, 2)
2128
+ value = value.transpose(1, 2)
2129
+
2130
+ # Attention.
2131
+ hidden_states = F.scaled_dot_product_attention(
2132
+ query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
2133
+ )
2134
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2135
+ hidden_states = hidden_states.to(query.dtype)
2136
+
2137
+ # Split the attention outputs.
2138
+ if encoder_hidden_states is not None:
2139
+ hidden_states, encoder_hidden_states = (
2140
+ hidden_states[:, encoder_hidden_states.shape[1] :],
2141
+ hidden_states[:, : encoder_hidden_states.shape[1]],
2142
+ )
2143
+
2144
+ # linear proj
2145
+ hidden_states = attn.to_out[0](hidden_states)
2146
+ # dropout
2147
+ hidden_states = attn.to_out[1](hidden_states)
2148
+ if encoder_hidden_states is not None:
2149
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
2150
+
2151
+ if encoder_hidden_states is not None:
2152
+ return hidden_states, encoder_hidden_states
2153
+ else:
2154
+ return hidden_states
2155
+
2156
+
2157
+ class FusedAuraFlowAttnProcessor2_0:
2158
+ """Attention processor used typically in processing Aura Flow with fused projections."""
2159
+
2160
+ def __init__(self):
2161
+ if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
2162
+ raise ImportError(
2163
+ "FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
2164
+ )
2165
+
2166
+ def __call__(
2167
+ self,
2168
+ attn: Attention,
2169
+ hidden_states: torch.FloatTensor,
2170
+ encoder_hidden_states: torch.FloatTensor = None,
2171
+ *args,
2172
+ **kwargs,
2173
+ ) -> torch.FloatTensor:
2174
+ batch_size = hidden_states.shape[0]
2175
+
2176
+ # `sample` projections.
2177
+ qkv = attn.to_qkv(hidden_states)
1638
2178
  split_size = qkv.shape[-1] // 3
1639
2179
  query, key, value = torch.split(qkv, split_size, dim=-1)
1640
2180
 
@@ -1778,7 +2318,321 @@ class FluxAttnProcessor2_0:
1778
2318
  query = apply_rotary_emb(query, image_rotary_emb)
1779
2319
  key = apply_rotary_emb(key, image_rotary_emb)
1780
2320
 
1781
- hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
2321
+ hidden_states = F.scaled_dot_product_attention(
2322
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
2323
+ )
2324
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2325
+ hidden_states = hidden_states.to(query.dtype)
2326
+
2327
+ if encoder_hidden_states is not None:
2328
+ encoder_hidden_states, hidden_states = (
2329
+ hidden_states[:, : encoder_hidden_states.shape[1]],
2330
+ hidden_states[:, encoder_hidden_states.shape[1] :],
2331
+ )
2332
+
2333
+ # linear proj
2334
+ hidden_states = attn.to_out[0](hidden_states)
2335
+ # dropout
2336
+ hidden_states = attn.to_out[1](hidden_states)
2337
+
2338
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
2339
+
2340
+ return hidden_states, encoder_hidden_states
2341
+ else:
2342
+ return hidden_states
2343
+
2344
+
2345
+ class FluxAttnProcessor2_0_NPU:
2346
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
2347
+
2348
+ def __init__(self):
2349
+ if not hasattr(F, "scaled_dot_product_attention"):
2350
+ raise ImportError(
2351
+ "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
2352
+ )
2353
+
2354
+ def __call__(
2355
+ self,
2356
+ attn: Attention,
2357
+ hidden_states: torch.FloatTensor,
2358
+ encoder_hidden_states: torch.FloatTensor = None,
2359
+ attention_mask: Optional[torch.FloatTensor] = None,
2360
+ image_rotary_emb: Optional[torch.Tensor] = None,
2361
+ ) -> torch.FloatTensor:
2362
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
2363
+
2364
+ # `sample` projections.
2365
+ query = attn.to_q(hidden_states)
2366
+ key = attn.to_k(hidden_states)
2367
+ value = attn.to_v(hidden_states)
2368
+
2369
+ inner_dim = key.shape[-1]
2370
+ head_dim = inner_dim // attn.heads
2371
+
2372
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2373
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2374
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2375
+
2376
+ if attn.norm_q is not None:
2377
+ query = attn.norm_q(query)
2378
+ if attn.norm_k is not None:
2379
+ key = attn.norm_k(key)
2380
+
2381
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
2382
+ if encoder_hidden_states is not None:
2383
+ # `context` projections.
2384
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
2385
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
2386
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
2387
+
2388
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
2389
+ batch_size, -1, attn.heads, head_dim
2390
+ ).transpose(1, 2)
2391
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
2392
+ batch_size, -1, attn.heads, head_dim
2393
+ ).transpose(1, 2)
2394
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
2395
+ batch_size, -1, attn.heads, head_dim
2396
+ ).transpose(1, 2)
2397
+
2398
+ if attn.norm_added_q is not None:
2399
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
2400
+ if attn.norm_added_k is not None:
2401
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
2402
+
2403
+ # attention
2404
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
2405
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
2406
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
2407
+
2408
+ if image_rotary_emb is not None:
2409
+ from .embeddings import apply_rotary_emb
2410
+
2411
+ query = apply_rotary_emb(query, image_rotary_emb)
2412
+ key = apply_rotary_emb(key, image_rotary_emb)
2413
+
2414
+ if query.dtype in (torch.float16, torch.bfloat16):
2415
+ hidden_states = torch_npu.npu_fusion_attention(
2416
+ query,
2417
+ key,
2418
+ value,
2419
+ attn.heads,
2420
+ input_layout="BNSD",
2421
+ pse=None,
2422
+ scale=1.0 / math.sqrt(query.shape[-1]),
2423
+ pre_tockens=65536,
2424
+ next_tockens=65536,
2425
+ keep_prob=1.0,
2426
+ sync=False,
2427
+ inner_precise=0,
2428
+ )[0]
2429
+ else:
2430
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
2431
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2432
+ hidden_states = hidden_states.to(query.dtype)
2433
+
2434
+ if encoder_hidden_states is not None:
2435
+ encoder_hidden_states, hidden_states = (
2436
+ hidden_states[:, : encoder_hidden_states.shape[1]],
2437
+ hidden_states[:, encoder_hidden_states.shape[1] :],
2438
+ )
2439
+
2440
+ # linear proj
2441
+ hidden_states = attn.to_out[0](hidden_states)
2442
+ # dropout
2443
+ hidden_states = attn.to_out[1](hidden_states)
2444
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
2445
+
2446
+ return hidden_states, encoder_hidden_states
2447
+ else:
2448
+ return hidden_states
2449
+
2450
+
2451
+ class FusedFluxAttnProcessor2_0:
2452
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
2453
+
2454
+ def __init__(self):
2455
+ if not hasattr(F, "scaled_dot_product_attention"):
2456
+ raise ImportError(
2457
+ "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
2458
+ )
2459
+
2460
+ def __call__(
2461
+ self,
2462
+ attn: Attention,
2463
+ hidden_states: torch.FloatTensor,
2464
+ encoder_hidden_states: torch.FloatTensor = None,
2465
+ attention_mask: Optional[torch.FloatTensor] = None,
2466
+ image_rotary_emb: Optional[torch.Tensor] = None,
2467
+ ) -> torch.FloatTensor:
2468
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
2469
+
2470
+ # `sample` projections.
2471
+ qkv = attn.to_qkv(hidden_states)
2472
+ split_size = qkv.shape[-1] // 3
2473
+ query, key, value = torch.split(qkv, split_size, dim=-1)
2474
+
2475
+ inner_dim = key.shape[-1]
2476
+ head_dim = inner_dim // attn.heads
2477
+
2478
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2479
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2480
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2481
+
2482
+ if attn.norm_q is not None:
2483
+ query = attn.norm_q(query)
2484
+ if attn.norm_k is not None:
2485
+ key = attn.norm_k(key)
2486
+
2487
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
2488
+ # `context` projections.
2489
+ if encoder_hidden_states is not None:
2490
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
2491
+ split_size = encoder_qkv.shape[-1] // 3
2492
+ (
2493
+ encoder_hidden_states_query_proj,
2494
+ encoder_hidden_states_key_proj,
2495
+ encoder_hidden_states_value_proj,
2496
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
2497
+
2498
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
2499
+ batch_size, -1, attn.heads, head_dim
2500
+ ).transpose(1, 2)
2501
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
2502
+ batch_size, -1, attn.heads, head_dim
2503
+ ).transpose(1, 2)
2504
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
2505
+ batch_size, -1, attn.heads, head_dim
2506
+ ).transpose(1, 2)
2507
+
2508
+ if attn.norm_added_q is not None:
2509
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
2510
+ if attn.norm_added_k is not None:
2511
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
2512
+
2513
+ # attention
2514
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
2515
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
2516
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
2517
+
2518
+ if image_rotary_emb is not None:
2519
+ from .embeddings import apply_rotary_emb
2520
+
2521
+ query = apply_rotary_emb(query, image_rotary_emb)
2522
+ key = apply_rotary_emb(key, image_rotary_emb)
2523
+
2524
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
2525
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2526
+ hidden_states = hidden_states.to(query.dtype)
2527
+
2528
+ if encoder_hidden_states is not None:
2529
+ encoder_hidden_states, hidden_states = (
2530
+ hidden_states[:, : encoder_hidden_states.shape[1]],
2531
+ hidden_states[:, encoder_hidden_states.shape[1] :],
2532
+ )
2533
+
2534
+ # linear proj
2535
+ hidden_states = attn.to_out[0](hidden_states)
2536
+ # dropout
2537
+ hidden_states = attn.to_out[1](hidden_states)
2538
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
2539
+
2540
+ return hidden_states, encoder_hidden_states
2541
+ else:
2542
+ return hidden_states
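`FusedFluxAttnProcessor2_0` reads a single fused `to_qkv` (and `to_added_qkv`) projection and splits it into thirds, so it only applies after the per-module Q/K/V weights have been fused. A hedged sketch, assuming `FluxTransformer2DModel.fuse_qkv_projections()` performs that fusion as it does for other diffusers transformers (an assumption, not shown in this hunk):

    # Sketch only: fuse Q/K/V weights, then install the fused processor.
    from diffusers import FluxTransformer2DModel
    from diffusers.models.attention_processor import FusedFluxAttnProcessor2_0

    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="transformer"  # illustrative checkpoint
    )
    transformer.fuse_qkv_projections()  # assumed to create attn.to_qkv / attn.to_added_qkv
    transformer.set_attn_processor(FusedFluxAttnProcessor2_0())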
2543
+
2544
+
2545
+ class FusedFluxAttnProcessor2_0_NPU:
2546
+ """Attention processor typically used for processing Flux-like self-attention projections with fused QKV weights, backed by the Ascend NPU fused attention kernel."""
2547
+
2548
+ def __init__(self):
2549
+ if not hasattr(F, "scaled_dot_product_attention"):
2550
+ raise ImportError(
2551
+ "FusedFluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch_npu. To use it, please upgrade PyTorch to 2.0 and install torch_npu."
2552
+ )
2553
+
2554
+ def __call__(
2555
+ self,
2556
+ attn: Attention,
2557
+ hidden_states: torch.FloatTensor,
2558
+ encoder_hidden_states: torch.FloatTensor = None,
2559
+ attention_mask: Optional[torch.FloatTensor] = None,
2560
+ image_rotary_emb: Optional[torch.Tensor] = None,
2561
+ ) -> torch.FloatTensor:
2562
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
2563
+
2564
+ # `sample` projections.
2565
+ qkv = attn.to_qkv(hidden_states)
2566
+ split_size = qkv.shape[-1] // 3
2567
+ query, key, value = torch.split(qkv, split_size, dim=-1)
2568
+
2569
+ inner_dim = key.shape[-1]
2570
+ head_dim = inner_dim // attn.heads
2571
+
2572
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2573
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2574
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2575
+
2576
+ if attn.norm_q is not None:
2577
+ query = attn.norm_q(query)
2578
+ if attn.norm_k is not None:
2579
+ key = attn.norm_k(key)
2580
+
2581
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
2582
+ # `context` projections.
2583
+ if encoder_hidden_states is not None:
2584
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
2585
+ split_size = encoder_qkv.shape[-1] // 3
2586
+ (
2587
+ encoder_hidden_states_query_proj,
2588
+ encoder_hidden_states_key_proj,
2589
+ encoder_hidden_states_value_proj,
2590
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
2591
+
2592
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
2593
+ batch_size, -1, attn.heads, head_dim
2594
+ ).transpose(1, 2)
2595
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
2596
+ batch_size, -1, attn.heads, head_dim
2597
+ ).transpose(1, 2)
2598
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
2599
+ batch_size, -1, attn.heads, head_dim
2600
+ ).transpose(1, 2)
2601
+
2602
+ if attn.norm_added_q is not None:
2603
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
2604
+ if attn.norm_added_k is not None:
2605
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
2606
+
2607
+ # attention
2608
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
2609
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
2610
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
2611
+
2612
+ if image_rotary_emb is not None:
2613
+ from .embeddings import apply_rotary_emb
2614
+
2615
+ query = apply_rotary_emb(query, image_rotary_emb)
2616
+ key = apply_rotary_emb(key, image_rotary_emb)
2617
+
2618
+ if query.dtype in (torch.float16, torch.bfloat16):
2619
+ hidden_states = torch_npu.npu_fusion_attention(
2620
+ query,
2621
+ key,
2622
+ value,
2623
+ attn.heads,
2624
+ input_layout="BNSD",
2625
+ pse=None,
2626
+ scale=1.0 / math.sqrt(query.shape[-1]),
2627
+ pre_tockens=65536,
2628
+ next_tockens=65536,
2629
+ keep_prob=1.0,
2630
+ sync=False,
2631
+ inner_precise=0,
2632
+ )[0]
2633
+ else:
2634
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
2635
+
1782
2636
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1783
2637
  hidden_states = hidden_states.to(query.dtype)
1784
2638
 
@@ -1799,15 +2653,44 @@ class FluxAttnProcessor2_0:
1799
2653
  return hidden_states
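Both NPU processors above dispatch on dtype: half-precision inputs go through `torch_npu.npu_fusion_attention` with the arguments shown, and everything else falls back to `F.scaled_dot_product_attention`. A self-contained sketch of that gate, copying the keyword names from the hunk rather than inventing them (`torch_npu` is only available on Ascend builds):

    # Hedged sketch of the dtype gate used by the NPU processors above.
    import math
    import torch
    import torch.nn.functional as F

    def npu_or_sdpa_attention(query, key, value, heads):
        if query.dtype in (torch.float16, torch.bfloat16):
            import torch_npu  # Ascend NPU builds only
            return torch_npu.npu_fusion_attention(
                query, key, value, heads,
                input_layout="BNSD", pse=None,
                scale=1.0 / math.sqrt(query.shape[-1]),
                pre_tockens=65536, next_tockens=65536,
                keep_prob=1.0, sync=False, inner_precise=0,
            )[0]
        return F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)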
1800
2654
 
1801
2655
 
1802
- class FusedFluxAttnProcessor2_0:
1803
- """Attention processor used typically in processing the SD3-like self-attention projections."""
2656
+ class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
2657
+ """Flux Attention processor for IP-Adapter."""
2658
+
2659
+ def __init__(
2660
+ self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
2661
+ ):
2662
+ super().__init__()
1804
2663
 
1805
- def __init__(self):
1806
2664
  if not hasattr(F, "scaled_dot_product_attention"):
1807
2665
  raise ImportError(
1808
- "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
2666
+ f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
1809
2667
  )
1810
2668
 
2669
+ self.hidden_size = hidden_size
2670
+ self.cross_attention_dim = cross_attention_dim
2671
+
2672
+ if not isinstance(num_tokens, (tuple, list)):
2673
+ num_tokens = [num_tokens]
2674
+
2675
+ if not isinstance(scale, list):
2676
+ scale = [scale] * len(num_tokens)
2677
+ if len(scale) != len(num_tokens):
2678
+ raise ValueError("`scale` should be a list of floats with the same length as `num_tokens`.")
2679
+ self.scale = scale
2680
+
2681
+ self.to_k_ip = nn.ModuleList(
2682
+ [
2683
+ nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
2684
+ for _ in range(len(num_tokens))
2685
+ ]
2686
+ )
2687
+ self.to_v_ip = nn.ModuleList(
2688
+ [
2689
+ nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
2690
+ for _ in range(len(num_tokens))
2691
+ ]
2692
+ )
2693
+
1811
2694
  def __call__(
1812
2695
  self,
1813
2696
  attn: Attention,
@@ -1815,36 +2698,34 @@ class FusedFluxAttnProcessor2_0:
1815
2698
  encoder_hidden_states: torch.FloatTensor = None,
1816
2699
  attention_mask: Optional[torch.FloatTensor] = None,
1817
2700
  image_rotary_emb: Optional[torch.Tensor] = None,
2701
+ ip_hidden_states: Optional[List[torch.Tensor]] = None,
2702
+ ip_adapter_masks: Optional[torch.Tensor] = None,
1818
2703
  ) -> torch.FloatTensor:
1819
2704
  batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1820
2705
 
1821
2706
  # `sample` projections.
1822
- qkv = attn.to_qkv(hidden_states)
1823
- split_size = qkv.shape[-1] // 3
1824
- query, key, value = torch.split(qkv, split_size, dim=-1)
2707
+ hidden_states_query_proj = attn.to_q(hidden_states)
2708
+ key = attn.to_k(hidden_states)
2709
+ value = attn.to_v(hidden_states)
1825
2710
 
1826
2711
  inner_dim = key.shape[-1]
1827
2712
  head_dim = inner_dim // attn.heads
1828
2713
 
1829
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2714
+ hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1830
2715
  key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1831
2716
  value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1832
2717
 
1833
2718
  if attn.norm_q is not None:
1834
- query = attn.norm_q(query)
2719
+ hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
1835
2720
  if attn.norm_k is not None:
1836
2721
  key = attn.norm_k(key)
1837
2722
 
1838
2723
  # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
1839
- # `context` projections.
1840
2724
  if encoder_hidden_states is not None:
1841
- encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
1842
- split_size = encoder_qkv.shape[-1] // 3
1843
- (
1844
- encoder_hidden_states_query_proj,
1845
- encoder_hidden_states_key_proj,
1846
- encoder_hidden_states_value_proj,
1847
- ) = torch.split(encoder_qkv, split_size, dim=-1)
2725
+ # `context` projections.
2726
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
2727
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
2728
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1848
2729
 
1849
2730
  encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
1850
2731
  batch_size, -1, attn.heads, head_dim
@@ -1862,7 +2743,7 @@ class FusedFluxAttnProcessor2_0:
1862
2743
  encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
1863
2744
 
1864
2745
  # attention
1865
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
2746
+ query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
1866
2747
  key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
1867
2748
  value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
1868
2749
 
@@ -1888,7 +2769,29 @@ class FusedFluxAttnProcessor2_0:
1888
2769
  hidden_states = attn.to_out[1](hidden_states)
1889
2770
  encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
1890
2771
 
1891
- return hidden_states, encoder_hidden_states
2772
+ # IP-adapter
2773
+ ip_query = hidden_states_query_proj
2774
+ ip_attn_output = None
2775
+ # for ip-adapter
2776
+ # TODO: support for multiple adapters
2777
+ for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
2778
+ ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
2779
+ ):
2780
+ ip_key = to_k_ip(current_ip_hidden_states)
2781
+ ip_value = to_v_ip(current_ip_hidden_states)
2782
+
2783
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2784
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2785
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
2786
+ # TODO: add support for attn.scale when we move to Torch 2.1
2787
+ ip_attn_output = F.scaled_dot_product_attention(
2788
+ ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
2789
+ )
2790
+ ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2791
+ ip_attn_output = scale * ip_attn_output
2792
+ ip_attn_output = ip_attn_output.to(ip_query.dtype)
2793
+
2794
+ return hidden_states, encoder_hidden_states, ip_attn_output
1892
2795
  else:
1893
2796
  return hidden_states
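`FluxIPAdapterJointAttnProcessor2_0` normalizes `num_tokens` to a list, broadcasts a scalar `scale` across it, and creates one `to_k_ip`/`to_v_ip` linear pair per IP-Adapter; when `encoder_hidden_states` is passed, `__call__` returns `(hidden_states, encoder_hidden_states, ip_attn_output)`. A construction sketch with illustrative dimensions (the numbers are assumptions, not taken from this diff):

    from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0

    processor = FluxIPAdapterJointAttnProcessor2_0(
        hidden_size=3072,            # joint attention width (assumed)
        cross_attention_dim=4096,    # image-embedding width (assumed)
        num_tokens=(4,),             # one adapter with 4 image tokens
        scale=1.0,
    )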
1894
2797
 
@@ -2285,7 +3188,217 @@ class AttnProcessorNPU:
2285
3188
  inner_precise=0,
2286
3189
  )[0]
2287
3190
  else:
2288
- # TODO: add support for attn.scale when we move to Torch 2.1
3191
+ # TODO: add support for attn.scale when we move to Torch 2.1
3192
+ hidden_states = F.scaled_dot_product_attention(
3193
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
3194
+ )
3195
+
3196
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
3197
+ hidden_states = hidden_states.to(query.dtype)
3198
+
3199
+ # linear proj
3200
+ hidden_states = attn.to_out[0](hidden_states)
3201
+ # dropout
3202
+ hidden_states = attn.to_out[1](hidden_states)
3203
+
3204
+ if input_ndim == 4:
3205
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
3206
+
3207
+ if attn.residual_connection:
3208
+ hidden_states = hidden_states + residual
3209
+
3210
+ hidden_states = hidden_states / attn.rescale_output_factor
3211
+
3212
+ return hidden_states
3213
+
3214
+
3215
+ class AttnProcessor2_0:
3216
+ r"""
3217
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
3218
+ """
3219
+
3220
+ def __init__(self):
3221
+ if not hasattr(F, "scaled_dot_product_attention"):
3222
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
3223
+
3224
+ def __call__(
3225
+ self,
3226
+ attn: Attention,
3227
+ hidden_states: torch.Tensor,
3228
+ encoder_hidden_states: Optional[torch.Tensor] = None,
3229
+ attention_mask: Optional[torch.Tensor] = None,
3230
+ temb: Optional[torch.Tensor] = None,
3231
+ *args,
3232
+ **kwargs,
3233
+ ) -> torch.Tensor:
3234
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
3235
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
3236
+ deprecate("scale", "1.0.0", deprecation_message)
3237
+
3238
+ residual = hidden_states
3239
+ if attn.spatial_norm is not None:
3240
+ hidden_states = attn.spatial_norm(hidden_states, temb)
3241
+
3242
+ input_ndim = hidden_states.ndim
3243
+
3244
+ if input_ndim == 4:
3245
+ batch_size, channel, height, width = hidden_states.shape
3246
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
3247
+
3248
+ batch_size, sequence_length, _ = (
3249
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
3250
+ )
3251
+
3252
+ if attention_mask is not None:
3253
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
3254
+ # scaled_dot_product_attention expects attention_mask shape to be
3255
+ # (batch, heads, source_length, target_length)
3256
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
3257
+
3258
+ if attn.group_norm is not None:
3259
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
3260
+
3261
+ query = attn.to_q(hidden_states)
3262
+
3263
+ if encoder_hidden_states is None:
3264
+ encoder_hidden_states = hidden_states
3265
+ elif attn.norm_cross:
3266
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
3267
+
3268
+ key = attn.to_k(encoder_hidden_states)
3269
+ value = attn.to_v(encoder_hidden_states)
3270
+
3271
+ inner_dim = key.shape[-1]
3272
+ head_dim = inner_dim // attn.heads
3273
+
3274
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3275
+
3276
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3277
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3278
+
3279
+ if attn.norm_q is not None:
3280
+ query = attn.norm_q(query)
3281
+ if attn.norm_k is not None:
3282
+ key = attn.norm_k(key)
3283
+
3284
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
3285
+ # TODO: add support for attn.scale when we move to Torch 2.1
3286
+ hidden_states = F.scaled_dot_product_attention(
3287
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
3288
+ )
3289
+
3290
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
3291
+ hidden_states = hidden_states.to(query.dtype)
3292
+
3293
+ # linear proj
3294
+ hidden_states = attn.to_out[0](hidden_states)
3295
+ # dropout
3296
+ hidden_states = attn.to_out[1](hidden_states)
3297
+
3298
+ if input_ndim == 4:
3299
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
3300
+
3301
+ if attn.residual_connection:
3302
+ hidden_states = hidden_states + residual
3303
+
3304
+ hidden_states = hidden_states / attn.rescale_output_factor
3305
+
3306
+ return hidden_states
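`AttnProcessor2_0` is the SDPA-backed default, but it can also be set explicitly. A minimal sketch, assuming a UNet-style model that exposes `set_attn_processor` (true for `UNet2DConditionModel`; the checkpoint path is a placeholder):

    from diffusers import UNet2DConditionModel
    from diffusers.models.attention_processor import AttnProcessor2_0

    unet = UNet2DConditionModel.from_pretrained(
        "path/to/checkpoint", subfolder="unet"  # hypothetical checkpoint path
    )
    unet.set_attn_processor(AttnProcessor2_0())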
3307
+
3308
+
3309
+ class XLAFlashAttnProcessor2_0:
3310
+ r"""
3311
+ Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
3312
+ """
3313
+
3314
+ def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
3315
+ if not hasattr(F, "scaled_dot_product_attention"):
3316
+ raise ImportError(
3317
+ "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
3318
+ )
3319
+ if is_torch_xla_version("<", "2.3"):
3320
+ raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
3321
+ if is_spmd() and is_torch_xla_version("<", "2.4"):
3322
+ raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
3323
+ self.partition_spec = partition_spec
3324
+
3325
+ def __call__(
3326
+ self,
3327
+ attn: Attention,
3328
+ hidden_states: torch.Tensor,
3329
+ encoder_hidden_states: Optional[torch.Tensor] = None,
3330
+ attention_mask: Optional[torch.Tensor] = None,
3331
+ temb: Optional[torch.Tensor] = None,
3332
+ *args,
3333
+ **kwargs,
3334
+ ) -> torch.Tensor:
3335
+ residual = hidden_states
3336
+ if attn.spatial_norm is not None:
3337
+ hidden_states = attn.spatial_norm(hidden_states, temb)
3338
+
3339
+ input_ndim = hidden_states.ndim
3340
+
3341
+ if input_ndim == 4:
3342
+ batch_size, channel, height, width = hidden_states.shape
3343
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
3344
+
3345
+ batch_size, sequence_length, _ = (
3346
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
3347
+ )
3348
+
3349
+ if attention_mask is not None:
3350
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
3351
+ # scaled_dot_product_attention expects attention_mask shape to be
3352
+ # (batch, heads, source_length, target_length)
3353
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
3354
+
3355
+ if attn.group_norm is not None:
3356
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
3357
+
3358
+ query = attn.to_q(hidden_states)
3359
+
3360
+ if encoder_hidden_states is None:
3361
+ encoder_hidden_states = hidden_states
3362
+ elif attn.norm_cross:
3363
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
3364
+
3365
+ key = attn.to_k(encoder_hidden_states)
3366
+ value = attn.to_v(encoder_hidden_states)
3367
+
3368
+ inner_dim = key.shape[-1]
3369
+ head_dim = inner_dim // attn.heads
3370
+
3371
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3372
+
3373
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3374
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3375
+
3376
+ if attn.norm_q is not None:
3377
+ query = attn.norm_q(query)
3378
+ if attn.norm_k is not None:
3379
+ key = attn.norm_k(key)
3380
+
3381
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
3382
+ # TODO: add support for attn.scale when we move to Torch 2.1
3383
+ if all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
3384
+ if attention_mask is not None:
3385
+ attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
3386
+ # Convert mask to float and replace 0s with -inf and 1s with 0
3387
+ attention_mask = (
3388
+ attention_mask.float()
3389
+ .masked_fill(attention_mask == 0, float("-inf"))
3390
+ .masked_fill(attention_mask == 1, float(0.0))
3391
+ )
3392
+
3393
+ # Apply attention mask to key
3394
+ key = key + attention_mask
3395
+ query /= math.sqrt(query.shape[3])
3396
+ partition_spec = self.partition_spec if is_spmd() else None
3397
+ hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
3398
+ else:
3399
+ logger.warning(
3400
+ "Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096."
3401
+ )
2289
3402
  hidden_states = F.scaled_dot_product_attention(
2290
3403
  query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
2291
3404
  )
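Before calling the pallas `flash_attention` kernel, `XLAFlashAttnProcessor2_0` converts a 0/1 attention mask into an additive bias, adds it to `key`, and pre-scales `query`; sequences shorter than 4096 tokens fall back to `F.scaled_dot_product_attention`. The mask conversion reduces to the following self-contained sketch:

    # 0/1 mask -> additive bias (-inf where masked, 0 where kept), as done above.
    import torch

    attention_mask = torch.tensor([[1, 1, 0, 0]])
    bias = (
        attention_mask.float()
        .masked_fill(attention_mask == 0, float("-inf"))
        .masked_fill(attention_mask == 1, 0.0)
    )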
@@ -2309,9 +3422,9 @@ class AttnProcessorNPU:
2309
3422
  return hidden_states
2310
3423
 
2311
3424
 
2312
- class AttnProcessor2_0:
3425
+ class MochiVaeAttnProcessor2_0:
2313
3426
  r"""
2314
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
3427
+ Attention processor used in Mochi VAE.
2315
3428
  """
2316
3429
 
2317
3430
  def __init__(self):
@@ -2324,23 +3437,9 @@ class AttnProcessor2_0:
2324
3437
  hidden_states: torch.Tensor,
2325
3438
  encoder_hidden_states: Optional[torch.Tensor] = None,
2326
3439
  attention_mask: Optional[torch.Tensor] = None,
2327
- temb: Optional[torch.Tensor] = None,
2328
- *args,
2329
- **kwargs,
2330
3440
  ) -> torch.Tensor:
2331
- if len(args) > 0 or kwargs.get("scale", None) is not None:
2332
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
2333
- deprecate("scale", "1.0.0", deprecation_message)
2334
-
2335
3441
  residual = hidden_states
2336
- if attn.spatial_norm is not None:
2337
- hidden_states = attn.spatial_norm(hidden_states, temb)
2338
-
2339
- input_ndim = hidden_states.ndim
2340
-
2341
- if input_ndim == 4:
2342
- batch_size, channel, height, width = hidden_states.shape
2343
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
3442
+ is_single_frame = hidden_states.shape[1] == 1
2344
3443
 
2345
3444
  batch_size, sequence_length, _ = (
2346
3445
  hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
@@ -2352,15 +3451,24 @@ class AttnProcessor2_0:
2352
3451
  # (batch, heads, source_length, target_length)
2353
3452
  attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
2354
3453
 
2355
- if attn.group_norm is not None:
2356
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
3454
+ if is_single_frame:
3455
+ hidden_states = attn.to_v(hidden_states)
3456
+
3457
+ # linear proj
3458
+ hidden_states = attn.to_out[0](hidden_states)
3459
+ # dropout
3460
+ hidden_states = attn.to_out[1](hidden_states)
3461
+
3462
+ if attn.residual_connection:
3463
+ hidden_states = hidden_states + residual
3464
+
3465
+ hidden_states = hidden_states / attn.rescale_output_factor
3466
+ return hidden_states
2357
3467
 
2358
3468
  query = attn.to_q(hidden_states)
2359
3469
 
2360
3470
  if encoder_hidden_states is None:
2361
3471
  encoder_hidden_states = hidden_states
2362
- elif attn.norm_cross:
2363
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
2364
3472
 
2365
3473
  key = attn.to_k(encoder_hidden_states)
2366
3474
  value = attn.to_v(encoder_hidden_states)
@@ -2369,7 +3477,6 @@ class AttnProcessor2_0:
2369
3477
  head_dim = inner_dim // attn.heads
2370
3478
 
2371
3479
  query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2372
-
2373
3480
  key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2374
3481
  value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2375
3482
 
@@ -2381,7 +3488,7 @@ class AttnProcessor2_0:
2381
3488
  # the output of sdp = (batch, num_heads, seq_len, head_dim)
2382
3489
  # TODO: add support for attn.scale when we move to Torch 2.1
2383
3490
  hidden_states = F.scaled_dot_product_attention(
2384
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
3491
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=attn.is_causal
2385
3492
  )
2386
3493
 
2387
3494
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@@ -2392,9 +3499,6 @@ class AttnProcessor2_0:
2392
3499
  # dropout
2393
3500
  hidden_states = attn.to_out[1](hidden_states)
2394
3501
 
2395
- if input_ndim == 4:
2396
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
2397
-
2398
3502
  if attn.residual_connection:
2399
3503
  hidden_states = hidden_states + residual
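`MochiVaeAttnProcessor2_0` short-circuits single-frame inputs (`hidden_states.shape[1] == 1`): only the value and output projections run, skipping Q/K and the attention call, while multi-frame inputs go through SDPA with `is_causal=attn.is_causal`. A hedged sketch of that shortcut:

    # Single-frame fast path, mirroring the hunk above; `attn` is an Attention module.
    def single_frame_fast_path(attn, hidden_states, residual):
        if hidden_states.shape[1] == 1:
            hidden_states = attn.to_v(hidden_states)
            hidden_states = attn.to_out[1](attn.to_out[0](hidden_states))
            if attn.residual_connection:
                hidden_states = hidden_states + residual
            return hidden_states / attn.rescale_output_factor
        return None  # fall through to the full attention path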
2400
3504
 
@@ -3597,34 +4701,232 @@ class SpatialNorm(nn.Module):
3597
4701
  """
3598
4702
  Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.
3599
4703
 
3600
- Args:
3601
- f_channels (`int`):
3602
- The number of channels for input to group normalization layer, and output of the spatial norm layer.
3603
- zq_channels (`int`):
3604
- The number of channels for the quantized vector as described in the paper.
3605
- """
4704
+ Args:
4705
+ f_channels (`int`):
4706
+ The number of channels for input to group normalization layer, and output of the spatial norm layer.
4707
+ zq_channels (`int`):
4708
+ The number of channels for the quantized vector as described in the paper.
4709
+ """
4710
+
4711
+ def __init__(
4712
+ self,
4713
+ f_channels: int,
4714
+ zq_channels: int,
4715
+ ):
4716
+ super().__init__()
4717
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
4718
+ self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
4719
+ self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
4720
+
4721
+ def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
4722
+ f_size = f.shape[-2:]
4723
+ zq = F.interpolate(zq, size=f_size, mode="nearest")
4724
+ norm_f = self.norm_layer(f)
4725
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
4726
+ return new_f
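`SpatialNorm` group-normalizes `f` and modulates it with two 1x1 convolutions of the quantized code `zq`, which is first nearest-neighbour upsampled to `f`'s spatial size. A self-contained sketch with dummy tensors (shapes are illustrative):

    import torch
    from diffusers.models.attention_processor import SpatialNorm

    norm = SpatialNorm(f_channels=64, zq_channels=4)
    f = torch.randn(1, 64, 32, 32)
    zq = torch.randn(1, 4, 8, 8)   # upsampled to f's spatial size internally
    out = norm(f, zq)              # shape (1, 64, 32, 32)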
4727
+
4728
+
4729
+ class IPAdapterAttnProcessor(nn.Module):
4730
+ r"""
4731
+ Attention processor for Multiple IP-Adapters.
4732
+
4733
+ Args:
4734
+ hidden_size (`int`):
4735
+ The hidden size of the attention layer.
4736
+ cross_attention_dim (`int`):
4737
+ The number of channels in the `encoder_hidden_states`.
4738
+ num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
4739
+ The context length of the image features.
4740
+ scale (`float` or List[`float`], defaults to 1.0):
4741
+ the weight scale of image prompt.
4742
+ """
4743
+
4744
+ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
4745
+ super().__init__()
4746
+
4747
+ self.hidden_size = hidden_size
4748
+ self.cross_attention_dim = cross_attention_dim
4749
+
4750
+ if not isinstance(num_tokens, (tuple, list)):
4751
+ num_tokens = [num_tokens]
4752
+ self.num_tokens = num_tokens
4753
+
4754
+ if not isinstance(scale, list):
4755
+ scale = [scale] * len(num_tokens)
4756
+ if len(scale) != len(num_tokens):
4757
+ raise ValueError("`scale` should be a list of floats with the same length as `num_tokens`.")
4758
+ self.scale = scale
4759
+
4760
+ self.to_k_ip = nn.ModuleList(
4761
+ [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
4762
+ )
4763
+ self.to_v_ip = nn.ModuleList(
4764
+ [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
4765
+ )
4766
+
4767
+ def __call__(
4768
+ self,
4769
+ attn: Attention,
4770
+ hidden_states: torch.Tensor,
4771
+ encoder_hidden_states: Optional[torch.Tensor] = None,
4772
+ attention_mask: Optional[torch.Tensor] = None,
4773
+ temb: Optional[torch.Tensor] = None,
4774
+ scale: float = 1.0,
4775
+ ip_adapter_masks: Optional[torch.Tensor] = None,
4776
+ ):
4777
+ residual = hidden_states
4778
+
4779
+ # separate ip_hidden_states from encoder_hidden_states
4780
+ if encoder_hidden_states is not None:
4781
+ if isinstance(encoder_hidden_states, tuple):
4782
+ encoder_hidden_states, ip_hidden_states = encoder_hidden_states
4783
+ else:
4784
+ deprecation_message = (
4785
+ "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
4786
+ " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
4787
+ )
4788
+ deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
4789
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
4790
+ encoder_hidden_states, ip_hidden_states = (
4791
+ encoder_hidden_states[:, :end_pos, :],
4792
+ [encoder_hidden_states[:, end_pos:, :]],
4793
+ )
4794
+
4795
+ if attn.spatial_norm is not None:
4796
+ hidden_states = attn.spatial_norm(hidden_states, temb)
4797
+
4798
+ input_ndim = hidden_states.ndim
4799
+
4800
+ if input_ndim == 4:
4801
+ batch_size, channel, height, width = hidden_states.shape
4802
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
4803
+
4804
+ batch_size, sequence_length, _ = (
4805
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
4806
+ )
4807
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
4808
+
4809
+ if attn.group_norm is not None:
4810
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
4811
+
4812
+ query = attn.to_q(hidden_states)
4813
+
4814
+ if encoder_hidden_states is None:
4815
+ encoder_hidden_states = hidden_states
4816
+ elif attn.norm_cross:
4817
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
4818
+
4819
+ key = attn.to_k(encoder_hidden_states)
4820
+ value = attn.to_v(encoder_hidden_states)
4821
+
4822
+ query = attn.head_to_batch_dim(query)
4823
+ key = attn.head_to_batch_dim(key)
4824
+ value = attn.head_to_batch_dim(value)
4825
+
4826
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
4827
+ hidden_states = torch.bmm(attention_probs, value)
4828
+ hidden_states = attn.batch_to_head_dim(hidden_states)
4829
+
4830
+ if ip_adapter_masks is not None:
4831
+ if not isinstance(ip_adapter_masks, List):
4832
+ # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
4833
+ ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
4834
+ if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
4835
+ raise ValueError(
4836
+ f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
4837
+ f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
4838
+ f"({len(ip_hidden_states)})"
4839
+ )
4840
+ else:
4841
+ for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
4842
+ if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
4843
+ raise ValueError(
4844
+ "Each element of the ip_adapter_masks array should be a tensor with shape "
4845
+ "[1, num_images_for_ip_adapter, height, width]."
4846
+ " Please use `IPAdapterMaskProcessor` to preprocess your mask"
4847
+ )
4848
+ if mask.shape[1] != ip_state.shape[1]:
4849
+ raise ValueError(
4850
+ f"Number of masks ({mask.shape[1]}) does not match "
4851
+ f"number of ip images ({ip_state.shape[1]}) at index {index}"
4852
+ )
4853
+ if isinstance(scale, list) and not len(scale) == mask.shape[1]:
4854
+ raise ValueError(
4855
+ f"Number of masks ({mask.shape[1]}) does not match "
4856
+ f"number of scales ({len(scale)}) at index {index}"
4857
+ )
4858
+ else:
4859
+ ip_adapter_masks = [None] * len(self.scale)
4860
+
4861
+ # for ip-adapter
4862
+ for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
4863
+ ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
4864
+ ):
4865
+ skip = False
4866
+ if isinstance(scale, list):
4867
+ if all(s == 0 for s in scale):
4868
+ skip = True
4869
+ elif scale == 0:
4870
+ skip = True
4871
+ if not skip:
4872
+ if mask is not None:
4873
+ if not isinstance(scale, list):
4874
+ scale = [scale] * mask.shape[1]
4875
+
4876
+ current_num_images = mask.shape[1]
4877
+ for i in range(current_num_images):
4878
+ ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
4879
+ ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
4880
+
4881
+ ip_key = attn.head_to_batch_dim(ip_key)
4882
+ ip_value = attn.head_to_batch_dim(ip_value)
4883
+
4884
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
4885
+ _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
4886
+ _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
4887
+
4888
+ mask_downsample = IPAdapterMaskProcessor.downsample(
4889
+ mask[:, i, :, :],
4890
+ batch_size,
4891
+ _current_ip_hidden_states.shape[1],
4892
+ _current_ip_hidden_states.shape[2],
4893
+ )
4894
+
4895
+ mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
4896
+
4897
+ hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
4898
+ else:
4899
+ ip_key = to_k_ip(current_ip_hidden_states)
4900
+ ip_value = to_v_ip(current_ip_hidden_states)
4901
+
4902
+ ip_key = attn.head_to_batch_dim(ip_key)
4903
+ ip_value = attn.head_to_batch_dim(ip_value)
4904
+
4905
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
4906
+ current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
4907
+ current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
4908
+
4909
+ hidden_states = hidden_states + scale * current_ip_hidden_states
4910
+
4911
+ # linear proj
4912
+ hidden_states = attn.to_out[0](hidden_states)
4913
+ # dropout
4914
+ hidden_states = attn.to_out[1](hidden_states)
4915
+
4916
+ if input_ndim == 4:
4917
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
4918
+
4919
+ if attn.residual_connection:
4920
+ hidden_states = hidden_states + residual
3606
4921
 
3607
- def __init__(
3608
- self,
3609
- f_channels: int,
3610
- zq_channels: int,
3611
- ):
3612
- super().__init__()
3613
- self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
3614
- self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
3615
- self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
4922
+ hidden_states = hidden_states / attn.rescale_output_factor
3616
4923
 
3617
- def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
3618
- f_size = f.shape[-2:]
3619
- zq = F.interpolate(zq, size=f_size, mode="nearest")
3620
- norm_f = self.norm_layer(f)
3621
- new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
3622
- return new_f
4924
+ return hidden_states
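`IPAdapterAttnProcessor` supports several IP-Adapters at once: `num_tokens` and `scale` are normalized to lists of equal length, and one `to_k_ip`/`to_v_ip` pair is created per adapter. A construction sketch with assumed dimensions:

    from diffusers.models.attention_processor import IPAdapterAttnProcessor

    # Two IP-Adapters with different weights (dimensions are illustrative).
    processor = IPAdapterAttnProcessor(
        hidden_size=1280,
        cross_attention_dim=1024,
        num_tokens=[4, 16],
        scale=[1.0, 0.5],
    )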
3623
4925
 
3624
4926
 
3625
- class IPAdapterAttnProcessor(nn.Module):
4927
+ class IPAdapterAttnProcessor2_0(torch.nn.Module):
3626
4928
  r"""
3627
- Attention processor for Multiple IP-Adapters.
4929
+ Attention processor for IP-Adapter for PyTorch 2.0.
3628
4930
 
3629
4931
  Args:
3630
4932
  hidden_size (`int`):
@@ -3633,13 +4935,18 @@ class IPAdapterAttnProcessor(nn.Module):
3633
4935
  The number of channels in the `encoder_hidden_states`.
3634
4936
  num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
3635
4937
  The context length of the image features.
3636
- scale (`float` or List[`float`], defaults to 1.0):
4938
+ scale (`float` or `List[float]`, defaults to 1.0):
3637
4939
  the weight scale of image prompt.
3638
4940
  """
3639
4941
 
3640
4942
  def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
3641
4943
  super().__init__()
3642
4944
 
4945
+ if not hasattr(F, "scaled_dot_product_attention"):
4946
+ raise ImportError(
4947
+ f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
4948
+ )
4949
+
3643
4950
  self.hidden_size = hidden_size
3644
4951
  self.cross_attention_dim = cross_attention_dim
3645
4952
 
@@ -3700,7 +5007,12 @@ class IPAdapterAttnProcessor(nn.Module):
3700
5007
  batch_size, sequence_length, _ = (
3701
5008
  hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
3702
5009
  )
3703
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
5010
+
5011
+ if attention_mask is not None:
5012
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
5013
+ # scaled_dot_product_attention expects attention_mask shape to be
5014
+ # (batch, heads, source_length, target_length)
5015
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
3704
5016
 
3705
5017
  if attn.group_norm is not None:
3706
5018
  hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@@ -3715,13 +5027,22 @@ class IPAdapterAttnProcessor(nn.Module):
3715
5027
  key = attn.to_k(encoder_hidden_states)
3716
5028
  value = attn.to_v(encoder_hidden_states)
3717
5029
 
3718
- query = attn.head_to_batch_dim(query)
3719
- key = attn.head_to_batch_dim(key)
3720
- value = attn.head_to_batch_dim(value)
5030
+ inner_dim = key.shape[-1]
5031
+ head_dim = inner_dim // attn.heads
3721
5032
 
3722
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
3723
- hidden_states = torch.bmm(attention_probs, value)
3724
- hidden_states = attn.batch_to_head_dim(hidden_states)
5033
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
5034
+
5035
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
5036
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
5037
+
5038
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
5039
+ # TODO: add support for attn.scale when we move to Torch 2.1
5040
+ hidden_states = F.scaled_dot_product_attention(
5041
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
5042
+ )
5043
+
5044
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
5045
+ hidden_states = hidden_states.to(query.dtype)
3725
5046
 
3726
5047
  if ip_adapter_masks is not None:
3727
5048
  if not isinstance(ip_adapter_masks, List):
@@ -3774,12 +5095,19 @@ class IPAdapterAttnProcessor(nn.Module):
3774
5095
  ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
3775
5096
  ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
3776
5097
 
3777
- ip_key = attn.head_to_batch_dim(ip_key)
3778
- ip_value = attn.head_to_batch_dim(ip_value)
5098
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
5099
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3779
5100
 
3780
- ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
3781
- _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
3782
- _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
5101
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
5102
+ # TODO: add support for attn.scale when we move to Torch 2.1
5103
+ _current_ip_hidden_states = F.scaled_dot_product_attention(
5104
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
5105
+ )
5106
+
5107
+ _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
5108
+ batch_size, -1, attn.heads * head_dim
5109
+ )
5110
+ _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
3783
5111
 
3784
5112
  mask_downsample = IPAdapterMaskProcessor.downsample(
3785
5113
  mask[:, i, :, :],
@@ -3789,18 +5117,24 @@ class IPAdapterAttnProcessor(nn.Module):
3789
5117
  )
3790
5118
 
3791
5119
  mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
3792
-
3793
5120
  hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
3794
5121
  else:
3795
5122
  ip_key = to_k_ip(current_ip_hidden_states)
3796
5123
  ip_value = to_v_ip(current_ip_hidden_states)
3797
5124
 
3798
- ip_key = attn.head_to_batch_dim(ip_key)
3799
- ip_value = attn.head_to_batch_dim(ip_value)
5125
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
5126
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3800
5127
 
3801
- ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
3802
- current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
3803
- current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
5128
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
5129
+ # TODO: add support for attn.scale when we move to Torch 2.1
5130
+ current_ip_hidden_states = F.scaled_dot_product_attention(
5131
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
5132
+ )
5133
+
5134
+ current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
5135
+ batch_size, -1, attn.heads * head_dim
5136
+ )
5137
+ current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
3804
5138
 
3805
5139
  hidden_states = hidden_states + scale * current_ip_hidden_states
3806
5140
 
@@ -3820,9 +5154,9 @@ class IPAdapterAttnProcessor(nn.Module):
3820
5154
  return hidden_states
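When `ip_adapter_masks` is provided, the IP-Adapter processors expect one `[1, num_images, height, width]` tensor per adapter and downsample it to the attention resolution with `IPAdapterMaskProcessor.downsample`. A hedged preprocessing sketch, assuming `IPAdapterMaskProcessor` from `diffusers.image_processor` and its `preprocess` method (not shown in this diff):

    from diffusers.image_processor import IPAdapterMaskProcessor

    mask_processor = IPAdapterMaskProcessor()
    # `masks` is a list of mask images, one per IP-Adapter image; height/width are illustrative.
    ip_adapter_masks = mask_processor.preprocess(masks, height=1024, width=1024)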
3821
5155
 
3822
5156
 
3823
- class IPAdapterAttnProcessor2_0(torch.nn.Module):
5157
+ class IPAdapterXFormersAttnProcessor(torch.nn.Module):
3824
5158
  r"""
3825
- Attention processor for IP-Adapter for PyTorch 2.0.
5159
+ Attention processor for IP-Adapter using xFormers.
3826
5160
 
3827
5161
  Args:
3828
5162
  hidden_size (`int`):
@@ -3833,18 +5167,26 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
3833
5167
  The context length of the image features.
3834
5168
  scale (`float` or `List[float]`, defaults to 1.0):
3835
5169
  the weight scale of image prompt.
5170
+ attention_op (`Callable`, *optional*, defaults to `None`):
5171
+ The base
5172
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
5173
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
5174
+ operator.
3836
5175
  """
3837
5176
 
3838
- def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
5177
+ def __init__(
5178
+ self,
5179
+ hidden_size,
5180
+ cross_attention_dim=None,
5181
+ num_tokens=(4,),
5182
+ scale=1.0,
5183
+ attention_op: Optional[Callable] = None,
5184
+ ):
3839
5185
  super().__init__()
3840
5186
 
3841
- if not hasattr(F, "scaled_dot_product_attention"):
3842
- raise ImportError(
3843
- f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
3844
- )
3845
-
3846
5187
  self.hidden_size = hidden_size
3847
5188
  self.cross_attention_dim = cross_attention_dim
5189
+ self.attention_op = attention_op
3848
5190
 
3849
5191
  if not isinstance(num_tokens, (tuple, list)):
3850
5192
  num_tokens = [num_tokens]
@@ -3857,21 +5199,21 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
3857
5199
  self.scale = scale
3858
5200
 
3859
5201
  self.to_k_ip = nn.ModuleList(
3860
- [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
5202
+ [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
3861
5203
  )
3862
5204
  self.to_v_ip = nn.ModuleList(
3863
- [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
5205
+ [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
3864
5206
  )
3865
5207
 
3866
5208
  def __call__(
3867
5209
  self,
3868
5210
  attn: Attention,
3869
- hidden_states: torch.Tensor,
3870
- encoder_hidden_states: Optional[torch.Tensor] = None,
3871
- attention_mask: Optional[torch.Tensor] = None,
3872
- temb: Optional[torch.Tensor] = None,
5211
+ hidden_states: torch.FloatTensor,
5212
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
5213
+ attention_mask: Optional[torch.FloatTensor] = None,
5214
+ temb: Optional[torch.FloatTensor] = None,
3873
5215
  scale: float = 1.0,
3874
- ip_adapter_masks: Optional[torch.Tensor] = None,
5216
+ ip_adapter_masks: Optional[torch.FloatTensor] = None,
3875
5217
  ):
3876
5218
  residual = hidden_states
3877
5219
 
@@ -3906,9 +5248,14 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
3906
5248
 
3907
5249
  if attention_mask is not None:
3908
5250
  attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
3909
- # scaled_dot_product_attention expects attention_mask shape to be
3910
- # (batch, heads, source_length, target_length)
3911
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
5251
+ # expand our mask's singleton query_tokens dimension:
5252
+ # [batch*heads, 1, key_tokens] ->
5253
+ # [batch*heads, query_tokens, key_tokens]
5254
+ # so that it can be added as a bias onto the attention scores that xformers computes:
5255
+ # [batch*heads, query_tokens, key_tokens]
5256
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
5257
+ _, query_tokens, _ = hidden_states.shape
5258
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
3912
5259
 
3913
5260
  if attn.group_norm is not None:
3914
5261
  hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@@ -3923,131 +5270,291 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
3923
5270
  key = attn.to_k(encoder_hidden_states)
3924
5271
  value = attn.to_v(encoder_hidden_states)
3925
5272
 
5273
+ query = attn.head_to_batch_dim(query).contiguous()
5274
+ key = attn.head_to_batch_dim(key).contiguous()
5275
+ value = attn.head_to_batch_dim(value).contiguous()
5276
+
5277
+ hidden_states = xformers.ops.memory_efficient_attention(
5278
+ query, key, value, attn_bias=attention_mask, op=self.attention_op
5279
+ )
5280
+ hidden_states = hidden_states.to(query.dtype)
5281
+ hidden_states = attn.batch_to_head_dim(hidden_states)
5282
+
5283
+ if ip_hidden_states:
5284
+ if ip_adapter_masks is not None:
5285
+ if not isinstance(ip_adapter_masks, List):
5286
+ # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
5287
+ ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
5288
+ if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
5289
+ raise ValueError(
5290
+ f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
5291
+ f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
5292
+ f"({len(ip_hidden_states)})"
5293
+ )
5294
+ else:
5295
+ for index, (mask, scale, ip_state) in enumerate(
5296
+ zip(ip_adapter_masks, self.scale, ip_hidden_states)
5297
+ ):
5298
+ if mask is None:
5299
+ continue
5300
+ if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
5301
+ raise ValueError(
5302
+ "Each element of the ip_adapter_masks array should be a tensor with shape "
5303
+ "[1, num_images_for_ip_adapter, height, width]."
5304
+ " Please use `IPAdapterMaskProcessor` to preprocess your mask"
5305
+ )
5306
+ if mask.shape[1] != ip_state.shape[1]:
5307
+ raise ValueError(
5308
+ f"Number of masks ({mask.shape[1]}) does not match "
5309
+ f"number of ip images ({ip_state.shape[1]}) at index {index}"
5310
+ )
5311
+ if isinstance(scale, list) and not len(scale) == mask.shape[1]:
5312
+ raise ValueError(
5313
+ f"Number of masks ({mask.shape[1]}) does not match "
5314
+ f"number of scales ({len(scale)}) at index {index}"
5315
+ )
5316
+ else:
5317
+ ip_adapter_masks = [None] * len(self.scale)
5318
+
5319
+ # for ip-adapter
5320
+ for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
5321
+ ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
5322
+ ):
5323
+ skip = False
5324
+ if isinstance(scale, list):
5325
+ if all(s == 0 for s in scale):
5326
+ skip = True
5327
+ elif scale == 0:
5328
+ skip = True
5329
+ if not skip:
5330
+ if mask is not None:
5331
+ mask = mask.to(torch.float16)
5332
+ if not isinstance(scale, list):
5333
+ scale = [scale] * mask.shape[1]
5334
+
5335
+ current_num_images = mask.shape[1]
5336
+ for i in range(current_num_images):
5337
+ ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
5338
+ ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
5339
+
5340
+ ip_key = attn.head_to_batch_dim(ip_key).contiguous()
5341
+ ip_value = attn.head_to_batch_dim(ip_value).contiguous()
5342
+
5343
+ _current_ip_hidden_states = xformers.ops.memory_efficient_attention(
5344
+ query, ip_key, ip_value, op=self.attention_op
5345
+ )
5346
+ _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
5347
+ _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
5348
+
5349
+ mask_downsample = IPAdapterMaskProcessor.downsample(
5350
+ mask[:, i, :, :],
5351
+ batch_size,
5352
+ _current_ip_hidden_states.shape[1],
5353
+ _current_ip_hidden_states.shape[2],
5354
+ )
5355
+
5356
+ mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
5357
+ hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
5358
+ else:
5359
+ ip_key = to_k_ip(current_ip_hidden_states)
5360
+ ip_value = to_v_ip(current_ip_hidden_states)
5361
+
5362
+ ip_key = attn.head_to_batch_dim(ip_key).contiguous()
5363
+ ip_value = attn.head_to_batch_dim(ip_value).contiguous()
5364
+
5365
+ current_ip_hidden_states = xformers.ops.memory_efficient_attention(
5366
+ query, ip_key, ip_value, op=self.attention_op
5367
+ )
5368
+ current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
5369
+ current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
5370
+
5371
+ hidden_states = hidden_states + scale * current_ip_hidden_states
5372
+
5373
+ # linear proj
5374
+ hidden_states = attn.to_out[0](hidden_states)
5375
+ # dropout
5376
+ hidden_states = attn.to_out[1](hidden_states)
5377
+
5378
+ if input_ndim == 4:
5379
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
5380
+
5381
+ if attn.residual_connection:
5382
+ hidden_states = hidden_states + residual
5383
+
5384
+ hidden_states = hidden_states / attn.rescale_output_factor
5385
+
5386
+ return hidden_states
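`IPAdapterXFormersAttnProcessor` mirrors the SDPA variant but routes every attention call through `xformers.ops.memory_efficient_attention`, optionally with a user-supplied `attention_op`. A construction sketch (dimensions are illustrative; `xformers` must be installed before the processor is called, and leaving `attention_op=None` lets xFormers pick the kernel, as the docstring above recommends):

    from diffusers.models.attention_processor import IPAdapterXFormersAttnProcessor

    processor = IPAdapterXFormersAttnProcessor(
        hidden_size=1280,
        cross_attention_dim=1024,
        num_tokens=(4,),
        scale=1.0,
        attention_op=None,
    )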
5387
+
5388
+
5389
+ class SD3IPAdapterJointAttnProcessor2_0(torch.nn.Module):
5390
+ """
5391
+ Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections, with
5392
+ additional image-based information and timestep embeddings.
5393
+
5394
+ Args:
5395
+ hidden_size (`int`):
5396
+ The number of hidden channels.
5397
+ ip_hidden_states_dim (`int`):
5398
+ The image feature dimension.
5399
+ head_dim (`int`):
5400
+ The number of head channels.
5401
+ timesteps_emb_dim (`int`, defaults to 1280):
5402
+ The number of input channels for timestep embedding.
5403
+ scale (`float`, defaults to 0.5):
5404
+ IP-Adapter scale.
5405
+ """
5406
+
5407
+ def __init__(
5408
+ self,
5409
+ hidden_size: int,
5410
+ ip_hidden_states_dim: int,
5411
+ head_dim: int,
5412
+ timesteps_emb_dim: int = 1280,
5413
+ scale: float = 0.5,
5414
+ ):
5415
+ super().__init__()
5416
+
5417
+ # To prevent circular import
5418
+ from .normalization import AdaLayerNorm, RMSNorm
5419
+
5420
+ self.norm_ip = AdaLayerNorm(timesteps_emb_dim, output_dim=ip_hidden_states_dim * 2, norm_eps=1e-6, chunk_dim=1)
5421
+ self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
5422
+ self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
5423
+ self.norm_q = RMSNorm(head_dim, 1e-6)
5424
+ self.norm_k = RMSNorm(head_dim, 1e-6)
5425
+ self.norm_ip_k = RMSNorm(head_dim, 1e-6)
5426
+ self.scale = scale
5427
+
5428
+ def __call__(
5429
+ self,
5430
+ attn: Attention,
5431
+ hidden_states: torch.FloatTensor,
5432
+ encoder_hidden_states: torch.FloatTensor = None,
5433
+ attention_mask: Optional[torch.FloatTensor] = None,
5434
+ ip_hidden_states: torch.FloatTensor = None,
5435
+ temb: torch.FloatTensor = None,
5436
+ ) -> torch.FloatTensor:
5437
+ """
5438
+ Perform the attention computation, integrating image features (if provided) and timestep embeddings.
5439
+
5440
+ If `ip_hidden_states` is `None`, this is equivalent to using JointAttnProcessor2_0.
5441
+
5442
+ Args:
5443
+ attn (`Attention`):
5444
+ Attention instance.
5445
+ hidden_states (`torch.FloatTensor`):
5446
+ Input `hidden_states`.
5447
+ encoder_hidden_states (`torch.FloatTensor`, *optional*):
5448
+ The encoder hidden states.
5449
+ attention_mask (`torch.FloatTensor`, *optional*):
5450
+ Attention mask.
5451
+ ip_hidden_states (`torch.FloatTensor`, *optional*):
5452
+ Image embeddings.
5453
+ temb (`torch.FloatTensor`, *optional*):
5454
+ Timestep embeddings.
5455
+
5456
+ Returns:
5457
+ `torch.FloatTensor`: Output hidden states.
5458
+ """
5459
+ residual = hidden_states
5460
+
5461
+ batch_size = hidden_states.shape[0]
5462
+
5463
+ # `sample` projections.
5464
+ query = attn.to_q(hidden_states)
5465
+ key = attn.to_k(hidden_states)
5466
+ value = attn.to_v(hidden_states)
5467
+
3926
5468
  inner_dim = key.shape[-1]
3927
5469
  head_dim = inner_dim // attn.heads
3928
5470
 
3929
5471
  query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3930
-
3931
5472
  key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3932
5473
  value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
5474
+ img_query = query
5475
+ img_key = key
5476
+ img_value = value
3933
5477
 
3934
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
3935
- # TODO: add support for attn.scale when we move to Torch 2.1
3936
- hidden_states = F.scaled_dot_product_attention(
3937
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
3938
- )
5478
+ if attn.norm_q is not None:
5479
+ query = attn.norm_q(query)
5480
+ if attn.norm_k is not None:
5481
+ key = attn.norm_k(key)
3939
5482
 
3940
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
3941
- hidden_states = hidden_states.to(query.dtype)
5483
+ # `context` projections.
5484
+ if encoder_hidden_states is not None:
5485
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
5486
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
5487
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
3942
5488
 
3943
- if ip_adapter_masks is not None:
3944
- if not isinstance(ip_adapter_masks, List):
3945
- # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
3946
- ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
3947
- if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
3948
- raise ValueError(
3949
- f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
3950
- f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
3951
- f"({len(ip_hidden_states)})"
3952
- )
3953
- else:
3954
- for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
3955
- if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
3956
- raise ValueError(
3957
- "Each element of the ip_adapter_masks array should be a tensor with shape "
3958
- "[1, num_images_for_ip_adapter, height, width]."
3959
- " Please use `IPAdapterMaskProcessor` to preprocess your mask"
3960
- )
3961
- if mask.shape[1] != ip_state.shape[1]:
3962
- raise ValueError(
3963
- f"Number of masks ({mask.shape[1]}) does not match "
3964
- f"number of ip images ({ip_state.shape[1]}) at index {index}"
3965
- )
3966
- if isinstance(scale, list) and not len(scale) == mask.shape[1]:
3967
- raise ValueError(
3968
- f"Number of masks ({mask.shape[1]}) does not match "
3969
- f"number of scales ({len(scale)}) at index {index}"
3970
- )
3971
- else:
3972
- ip_adapter_masks = [None] * len(self.scale)
5489
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
5490
+ batch_size, -1, attn.heads, head_dim
5491
+ ).transpose(1, 2)
5492
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
5493
+ batch_size, -1, attn.heads, head_dim
5494
+ ).transpose(1, 2)
5495
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
5496
+ batch_size, -1, attn.heads, head_dim
5497
+ ).transpose(1, 2)
3973
5498
 
3974
- # for ip-adapter
3975
- for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
3976
- ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
3977
- ):
3978
- skip = False
3979
- if isinstance(scale, list):
3980
- if all(s == 0 for s in scale):
3981
- skip = True
3982
- elif scale == 0:
3983
- skip = True
3984
- if not skip:
3985
- if mask is not None:
3986
- if not isinstance(scale, list):
3987
- scale = [scale] * mask.shape[1]
5499
+ if attn.norm_added_q is not None:
5500
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
5501
+ if attn.norm_added_k is not None:
5502
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
3988
5503
 
3989
- current_num_images = mask.shape[1]
3990
- for i in range(current_num_images):
3991
- ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
3992
- ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
5504
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
5505
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
5506
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
3993
5507
 
3994
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3995
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
5508
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
5509
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
5510
+ hidden_states = hidden_states.to(query.dtype)
3996
5511
 
3997
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
3998
- # TODO: add support for attn.scale when we move to Torch 2.1
3999
- _current_ip_hidden_states = F.scaled_dot_product_attention(
4000
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
4001
- )
5512
+ if encoder_hidden_states is not None:
5513
+ # Split the attention outputs.
5514
+ hidden_states, encoder_hidden_states = (
5515
+ hidden_states[:, : residual.shape[1]],
5516
+ hidden_states[:, residual.shape[1] :],
5517
+ )
5518
+ if not attn.context_pre_only:
5519
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
4002
5520
 
4003
- _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
4004
- batch_size, -1, attn.heads * head_dim
4005
- )
4006
- _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
5521
+ # IP Adapter
5522
+ if self.scale != 0 and ip_hidden_states is not None:
5523
+ # Norm image features
5524
+ norm_ip_hidden_states = self.norm_ip(ip_hidden_states, temb=temb)
4007
5525
 
4008
- mask_downsample = IPAdapterMaskProcessor.downsample(
4009
- mask[:, i, :, :],
4010
- batch_size,
4011
- _current_ip_hidden_states.shape[1],
4012
- _current_ip_hidden_states.shape[2],
4013
- )
5526
+ # To k and v
5527
+ ip_key = self.to_k_ip(norm_ip_hidden_states)
5528
+ ip_value = self.to_v_ip(norm_ip_hidden_states)
4014
5529
 
4015
- mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
4016
- hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
4017
- else:
4018
- ip_key = to_k_ip(current_ip_hidden_states)
4019
- ip_value = to_v_ip(current_ip_hidden_states)
5530
+ # Reshape
5531
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
5532
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
4020
5533
 
4021
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
4022
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
5534
+ # Norm
5535
+ query = self.norm_q(img_query)
5536
+ img_key = self.norm_k(img_key)
5537
+ ip_key = self.norm_ip_k(ip_key)
4023
5538
 
4024
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
4025
- # TODO: add support for attn.scale when we move to Torch 2.1
4026
- current_ip_hidden_states = F.scaled_dot_product_attention(
4027
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
4028
- )
5539
+ # cat img
5540
+ key = torch.cat([img_key, ip_key], dim=2)
5541
+ value = torch.cat([img_value, ip_value], dim=2)
4029
5542
 
4030
- current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
4031
- batch_size, -1, attn.heads * head_dim
4032
- )
4033
- current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
5543
+ ip_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
5544
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).view(batch_size, -1, attn.heads * head_dim)
5545
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
4034
5546
 
4035
- hidden_states = hidden_states + scale * current_ip_hidden_states
5547
+ hidden_states = hidden_states + ip_hidden_states * self.scale
4036
5548
 
4037
5549
  # linear proj
4038
5550
  hidden_states = attn.to_out[0](hidden_states)
4039
5551
  # dropout
4040
5552
  hidden_states = attn.to_out[1](hidden_states)
4041
5553
 
4042
- if input_ndim == 4:
4043
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
4044
-
4045
- if attn.residual_connection:
4046
- hidden_states = hidden_states + residual
4047
-
4048
- hidden_states = hidden_states / attn.rescale_output_factor
4049
-
4050
- return hidden_states
5554
+ if encoder_hidden_states is not None:
5555
+ return hidden_states, encoder_hidden_states
5556
+ else:
5557
+ return hidden_states
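
The new `SD3IPAdapterJointAttnProcessor2_0.__call__` above first runs joint attention over the concatenated image and text tokens, then, when `self.scale != 0` and `ip_hidden_states` is provided, runs a second attention pass that reuses the image query against keys/values projected from the IP-adapter features and adds the result back scaled by `self.scale`. The sketch below isolates that blending step with plain tensors; the function name and tensor layout are illustrative assumptions, not the diffusers API.

```python
# Minimal sketch of the IP-adapter blending step (illustrative only; the
# function name and shapes are assumptions, not the library's internals).
import torch
import torch.nn.functional as F

def ip_adapter_mix(hidden_states, img_query, img_key, img_value, ip_key, ip_value, scale=1.0):
    """Add an image-conditioned attention pass on top of the joint-attention output.

    `img_*` and `ip_*` tensors are (batch, heads, seq, head_dim); `hidden_states`
    is the already-projected joint-attention output (batch, img_seq, heads * head_dim).
    """
    # Image tokens attend over [image tokens || IP-adapter tokens].
    key = torch.cat([img_key, ip_key], dim=2)
    value = torch.cat([img_value, ip_value], dim=2)
    ip_out = F.scaled_dot_product_attention(img_query, key, value, dropout_p=0.0, is_causal=False)

    batch, heads, seq, head_dim = ip_out.shape
    ip_out = ip_out.transpose(1, 2).reshape(batch, seq, heads * head_dim)
    # `scale` plays the role of `self.scale`: 0 disables the IP-adapter branch.
    return hidden_states + scale * ip_out

b, h, s_img, s_ip, d = 1, 2, 8, 4, 16
out = ip_adapter_mix(
    hidden_states=torch.randn(b, s_img, h * d),
    img_query=torch.randn(b, h, s_img, d),
    img_key=torch.randn(b, h, s_img, d),
    img_value=torch.randn(b, h, s_img, d),
    ip_key=torch.randn(b, h, s_ip, d),
    ip_value=torch.randn(b, h, s_ip, d),
    scale=0.6,
)
print(out.shape)  # torch.Size([1, 8, 32])
```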
4051
5558
 
4052
5559
 
4053
5560
  class PAGIdentitySelfAttnProcessor2_0:
@@ -4252,22 +5759,98 @@ class PAGCFGIdentitySelfAttnProcessor2_0:
4252
5759
  return hidden_states
4253
5760
 
4254
5761
 
5762
+ class SanaMultiscaleAttnProcessor2_0:
5763
+ r"""
5764
+ Processor for implementing multiscale quadratic attention.
5765
+ """
5766
+
5767
+ def __call__(self, attn: SanaMultiscaleLinearAttention, hidden_states: torch.Tensor) -> torch.Tensor:
5768
+ height, width = hidden_states.shape[-2:]
5769
+ if height * width > attn.attention_head_dim:
5770
+ use_linear_attention = True
5771
+ else:
5772
+ use_linear_attention = False
5773
+
5774
+ residual = hidden_states
5775
+
5776
+ batch_size, _, height, width = list(hidden_states.size())
5777
+ original_dtype = hidden_states.dtype
5778
+
5779
+ hidden_states = hidden_states.movedim(1, -1)
5780
+ query = attn.to_q(hidden_states)
5781
+ key = attn.to_k(hidden_states)
5782
+ value = attn.to_v(hidden_states)
5783
+ hidden_states = torch.cat([query, key, value], dim=3)
5784
+ hidden_states = hidden_states.movedim(-1, 1)
5785
+
5786
+ multi_scale_qkv = [hidden_states]
5787
+ for block in attn.to_qkv_multiscale:
5788
+ multi_scale_qkv.append(block(hidden_states))
5789
+
5790
+ hidden_states = torch.cat(multi_scale_qkv, dim=1)
5791
+
5792
+ if use_linear_attention:
5793
+ # for linear attention upcast hidden_states to float32
5794
+ hidden_states = hidden_states.to(dtype=torch.float32)
5795
+
5796
+ hidden_states = hidden_states.reshape(batch_size, -1, 3 * attn.attention_head_dim, height * width)
5797
+
5798
+ query, key, value = hidden_states.chunk(3, dim=2)
5799
+ query = attn.nonlinearity(query)
5800
+ key = attn.nonlinearity(key)
5801
+
5802
+ if use_linear_attention:
5803
+ hidden_states = attn.apply_linear_attention(query, key, value)
5804
+ hidden_states = hidden_states.to(dtype=original_dtype)
5805
+ else:
5806
+ hidden_states = attn.apply_quadratic_attention(query, key, value)
5807
+
5808
+ hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
5809
+ hidden_states = attn.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
5810
+
5811
+ if attn.norm_type == "rms_norm":
5812
+ hidden_states = attn.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
5813
+ else:
5814
+ hidden_states = attn.norm_out(hidden_states)
5815
+
5816
+ if attn.residual_connection:
5817
+ hidden_states = hidden_states + residual
5818
+
5819
+ return hidden_states
5820
+
5821
+
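
`SanaMultiscaleAttnProcessor2_0` above picks between linear and quadratic attention based on `height * width > attn.attention_head_dim`: kernelized linear attention costs roughly the token count times the square of the head dimension, while softmax attention costs the square of the token count times the head dimension. The rough comparison below motivates that switch; it is a back-of-the-envelope sketch that ignores projection layers and constant factors.

```python
# Back-of-the-envelope cost model behind the linear/quadratic switch
# (illustrative only; ignores projections, constants, and memory traffic).
def attention_cost(num_tokens: int, head_dim: int) -> dict:
    return {
        "quadratic": num_tokens * num_tokens * head_dim,  # softmax attention ~ O(N^2 * d)
        "linear": num_tokens * head_dim * head_dim,       # ReLU linear attention ~ O(N * d^2)
    }

for side in (8, 32, 128):  # feature-map sizes -> N = side * side tokens
    print(side * side, attention_cost(side * side, head_dim=32))
# Once N is much larger than the head dimension the linear path wins, which is
# why the processor only falls back to the quadratic kernel for tiny maps.
```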
4255
5822
  class LoRAAttnProcessor:
5823
+ r"""
5824
+ Processor for implementing attention with LoRA.
5825
+ """
5826
+
4256
5827
  def __init__(self):
4257
5828
  pass
4258
5829
 
4259
5830
 
4260
5831
  class LoRAAttnProcessor2_0:
5832
+ r"""
5833
+ Processor for implementing attention with LoRA (enabled by default if you're using PyTorch 2.0).
5834
+ """
5835
+
4261
5836
  def __init__(self):
4262
5837
  pass
4263
5838
 
4264
5839
 
4265
5840
  class LoRAXFormersAttnProcessor:
5841
+ r"""
5842
+ Processor for implementing attention with LoRA using xFormers.
5843
+ """
5844
+
4266
5845
  def __init__(self):
4267
5846
  pass
4268
5847
 
4269
5848
 
4270
5849
  class LoRAAttnAddedKVProcessor:
5850
+ r"""
5851
+ Processor for implementing attention with LoRA with extra learnable key and value matrices for the text encoder.
5852
+ """
5853
+
4271
5854
  def __init__(self):
4272
5855
  pass
4273
5856
 
@@ -4283,6 +5866,165 @@ class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
4283
5866
  super().__init__()
4284
5867
 
4285
5868
 
5869
+ class SanaLinearAttnProcessor2_0:
5870
+ r"""
5871
+ Processor for implementing scaled dot-product linear attention.
5872
+ """
5873
+
5874
+ def __call__(
5875
+ self,
5876
+ attn: Attention,
5877
+ hidden_states: torch.Tensor,
5878
+ encoder_hidden_states: Optional[torch.Tensor] = None,
5879
+ attention_mask: Optional[torch.Tensor] = None,
5880
+ ) -> torch.Tensor:
5881
+ original_dtype = hidden_states.dtype
5882
+
5883
+ if encoder_hidden_states is None:
5884
+ encoder_hidden_states = hidden_states
5885
+
5886
+ query = attn.to_q(hidden_states)
5887
+ key = attn.to_k(encoder_hidden_states)
5888
+ value = attn.to_v(encoder_hidden_states)
5889
+
5890
+ query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
5891
+ key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
5892
+ value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
5893
+
5894
+ query = F.relu(query)
5895
+ key = F.relu(key)
5896
+
5897
+ query, key, value = query.float(), key.float(), value.float()
5898
+
5899
+ value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
5900
+ scores = torch.matmul(value, key)
5901
+ hidden_states = torch.matmul(scores, query)
5902
+
5903
+ hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + 1e-15)
5904
+ hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
5905
+ hidden_states = hidden_states.to(original_dtype)
5906
+
5907
+ hidden_states = attn.to_out[0](hidden_states)
5908
+ hidden_states = attn.to_out[1](hidden_states)
5909
+
5910
+ if original_dtype == torch.float16:
5911
+ hidden_states = hidden_states.clip(-65504, 65504)
5912
+
5913
+ return hidden_states
5914
+
5915
+
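
`SanaLinearAttnProcessor2_0` implements ReLU-kernel linear attention: after applying ReLU to queries and keys, it appends a row of ones to the value tensor so that a single `matmul` chain produces both the weighted values and the per-query normalizer, which is then divided out from the last row. The standalone check below reproduces that padding trick on small random tensors and compares it against the explicit normalization; the shapes are illustrative choices and this is not the library's code path.

```python
# Sanity check of the "pad value with ones" trick used above (standalone sketch).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq, dim = 1, 2, 5, 4
q = F.relu(torch.randn(batch, heads, dim, seq))                  # (B, H, d, N)
k = F.relu(torch.randn(batch, heads, dim, seq)).transpose(2, 3)  # (B, H, N, d)
v = torch.randn(batch, heads, dim, seq)                          # (B, H, d, N)

# Padding trick: the extra ones-row of V accumulates the per-query normalizer
# sum_j phi(k_j) . phi(q_i) through the same matmul chain as the values.
v_pad = F.pad(v, (0, 0, 0, 1), mode="constant", value=1.0)       # (B, H, d+1, N)
out = torch.matmul(torch.matmul(v_pad, k), q)                    # (B, H, d+1, N)
out = out[:, :, :-1] / (out[:, :, -1:] + 1e-15)                  # divide by normalizer row

# Same result written out explicitly: (V K) Q / ((sum_j k_j) Q).
numer = torch.matmul(torch.matmul(v, k), q)                      # (B, H, d, N)
denom = torch.matmul(k.sum(dim=2, keepdim=True), q) + 1e-15      # (B, H, 1, N)
assert torch.allclose(out, numer / denom, atol=1e-5)
print(out.shape)  # torch.Size([1, 2, 4, 5])
```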
5916
+ class PAGCFGSanaLinearAttnProcessor2_0:
5917
+ r"""
5918
+ Processor for implementing scaled dot-product linear attention.
5919
+ """
5920
+
5921
+ def __call__(
5922
+ self,
5923
+ attn: Attention,
5924
+ hidden_states: torch.Tensor,
5925
+ encoder_hidden_states: Optional[torch.Tensor] = None,
5926
+ attention_mask: Optional[torch.Tensor] = None,
5927
+ ) -> torch.Tensor:
5928
+ original_dtype = hidden_states.dtype
5929
+
5930
+ hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
5931
+ hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
5932
+
5933
+ query = attn.to_q(hidden_states_org)
5934
+ key = attn.to_k(hidden_states_org)
5935
+ value = attn.to_v(hidden_states_org)
5936
+
5937
+ query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
5938
+ key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
5939
+ value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
5940
+
5941
+ query = F.relu(query)
5942
+ key = F.relu(key)
5943
+
5944
+ query, key, value = query.float(), key.float(), value.float()
5945
+
5946
+ value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
5947
+ scores = torch.matmul(value, key)
5948
+ hidden_states_org = torch.matmul(scores, query)
5949
+
5950
+ hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + 1e-15)
5951
+ hidden_states_org = hidden_states_org.flatten(1, 2).transpose(1, 2)
5952
+ hidden_states_org = hidden_states_org.to(original_dtype)
5953
+
5954
+ hidden_states_org = attn.to_out[0](hidden_states_org)
5955
+ hidden_states_org = attn.to_out[1](hidden_states_org)
5956
+
5957
+ # perturbed path (identity attention)
5958
+ hidden_states_ptb = attn.to_v(hidden_states_ptb).to(original_dtype)
5959
+
5960
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
5961
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
5962
+
5963
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
5964
+
5965
+ if original_dtype == torch.float16:
5966
+ hidden_states = hidden_states.clip(-65504, 65504)
5967
+
5968
+ return hidden_states
5969
+
5970
+
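
`PAGCFGSanaLinearAttnProcessor2_0` assumes the latent batch is ordered as [unconditional, conditional, perturbed] and re-splits it with `hidden_states.chunk(3)`; the first two branches go through normal linear attention while the perturbed branch is handled separately. The sketch below shows, under that assumed convention, how a caller might build such a batch and recombine the three predictions with classifier-free guidance plus a PAG correction; the helper names and exact recombination weights are illustrative, not a diffusers API.

```python
# Illustrative batching convention for combined CFG + PAG (hypothetical helpers).
import torch

def build_pag_cfg_batch(latents: torch.Tensor) -> torch.Tensor:
    """Triplicate latents so one forward pass serves [uncond, cond, perturbed]."""
    return torch.cat([latents, latents, latents], dim=0)

def combine_pag_cfg(noise_pred: torch.Tensor, guidance_scale: float, pag_scale: float) -> torch.Tensor:
    """Standard CFG term plus a perturbed-attention-guidance correction."""
    uncond, cond, perturbed = noise_pred.chunk(3, dim=0)
    return uncond + guidance_scale * (cond - uncond) + pag_scale * (cond - perturbed)

latents = torch.randn(2, 4, 8, 8)
batch = build_pag_cfg_batch(latents)   # shape (6, 4, 8, 8), matching chunk(3) above
noise_pred = torch.randn_like(batch)   # stand-in for the transformer output
print(combine_pag_cfg(noise_pred, guidance_scale=4.5, pag_scale=3.0).shape)  # (2, 4, 8, 8)
```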
5971
+ class PAGIdentitySanaLinearAttnProcessor2_0:
5972
+ r"""
5973
+ Processor for implementing scaled dot-product linear attention.
5974
+ """
5975
+
5976
+ def __call__(
5977
+ self,
5978
+ attn: Attention,
5979
+ hidden_states: torch.Tensor,
5980
+ encoder_hidden_states: Optional[torch.Tensor] = None,
5981
+ attention_mask: Optional[torch.Tensor] = None,
5982
+ ) -> torch.Tensor:
5983
+ original_dtype = hidden_states.dtype
5984
+
5985
+ hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
5986
+
5987
+ query = attn.to_q(hidden_states_org)
5988
+ key = attn.to_k(hidden_states_org)
5989
+ value = attn.to_v(hidden_states_org)
5990
+
5991
+ query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
5992
+ key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
5993
+ value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
5994
+
5995
+ query = F.relu(query)
5996
+ key = F.relu(key)
5997
+
5998
+ query, key, value = query.float(), key.float(), value.float()
5999
+
6000
+ value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
6001
+ scores = torch.matmul(value, key)
6002
+ hidden_states_org = torch.matmul(scores, query)
6003
+
6004
+ if hidden_states_org.dtype in [torch.float16, torch.bfloat16]:
6005
+ hidden_states_org = hidden_states_org.float()
6006
+
6007
+ hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + 1e-15)
6008
+ hidden_states_org = hidden_states_org.flatten(1, 2).transpose(1, 2)
6009
+ hidden_states_org = hidden_states_org.to(original_dtype)
6010
+
6011
+ hidden_states_org = attn.to_out[0](hidden_states_org)
6012
+ hidden_states_org = attn.to_out[1](hidden_states_org)
6013
+
6014
+ # perturbed path (identity attention)
6015
+ hidden_states_ptb = attn.to_v(hidden_states_ptb).to(original_dtype)
6016
+
6017
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
6018
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
6019
+
6020
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
6021
+
6022
+ if original_dtype == torch.float16:
6023
+ hidden_states = hidden_states.clip(-65504, 65504)
6024
+
6025
+ return hidden_states
6026
+
6027
+
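
In both PAG Sana processors the perturbed branch is "identity attention": the attention map is replaced by the identity, so each token only passes through the value and output projections without mixing with other tokens. A minimal standalone module making that explicit is sketched below; the class and layer names are hypothetical, not part of diffusers.

```python
# "Identity attention" as used for the perturbed PAG branch (standalone sketch).
import torch
import torch.nn as nn

class IdentityAttentionBranch(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.to_v = nn.Linear(dim, dim)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Equivalent to attention with an identity attention matrix:
        # out_i = W_out(W_v(x_i)), with no mixing across tokens.
        return self.to_out(self.to_v(hidden_states))

x = torch.randn(1, 16, 64)
print(IdentityAttentionBranch(64)(x).shape)  # torch.Size([1, 16, 64])
```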
4286
6028
  ADDED_KV_ATTENTION_PROCESSORS = (
4287
6029
  AttnAddedKVProcessor,
4288
6030
  SlicedAttnAddedKVProcessor,
@@ -4297,23 +6039,59 @@ CROSS_ATTENTION_PROCESSORS = (
4297
6039
  SlicedAttnProcessor,
4298
6040
  IPAdapterAttnProcessor,
4299
6041
  IPAdapterAttnProcessor2_0,
6042
+ FluxIPAdapterJointAttnProcessor2_0,
4300
6043
  )
4301
6044
 
4302
6045
  AttentionProcessor = Union[
4303
6046
  AttnProcessor,
4304
- AttnProcessor2_0,
4305
- FusedAttnProcessor2_0,
4306
- XFormersAttnProcessor,
4307
- SlicedAttnProcessor,
6047
+ CustomDiffusionAttnProcessor,
4308
6048
  AttnAddedKVProcessor,
4309
- SlicedAttnAddedKVProcessor,
4310
6049
  AttnAddedKVProcessor2_0,
6050
+ JointAttnProcessor2_0,
6051
+ PAGJointAttnProcessor2_0,
6052
+ PAGCFGJointAttnProcessor2_0,
6053
+ FusedJointAttnProcessor2_0,
6054
+ AllegroAttnProcessor2_0,
6055
+ AuraFlowAttnProcessor2_0,
6056
+ FusedAuraFlowAttnProcessor2_0,
6057
+ FluxAttnProcessor2_0,
6058
+ FluxAttnProcessor2_0_NPU,
6059
+ FusedFluxAttnProcessor2_0,
6060
+ FusedFluxAttnProcessor2_0_NPU,
6061
+ CogVideoXAttnProcessor2_0,
6062
+ FusedCogVideoXAttnProcessor2_0,
4311
6063
  XFormersAttnAddedKVProcessor,
4312
- CustomDiffusionAttnProcessor,
6064
+ XFormersAttnProcessor,
6065
+ XLAFlashAttnProcessor2_0,
6066
+ AttnProcessorNPU,
6067
+ AttnProcessor2_0,
6068
+ MochiVaeAttnProcessor2_0,
6069
+ MochiAttnProcessor2_0,
6070
+ StableAudioAttnProcessor2_0,
6071
+ HunyuanAttnProcessor2_0,
6072
+ FusedHunyuanAttnProcessor2_0,
6073
+ PAGHunyuanAttnProcessor2_0,
6074
+ PAGCFGHunyuanAttnProcessor2_0,
6075
+ LuminaAttnProcessor2_0,
6076
+ FusedAttnProcessor2_0,
4313
6077
  CustomDiffusionXFormersAttnProcessor,
4314
6078
  CustomDiffusionAttnProcessor2_0,
4315
- PAGCFGIdentitySelfAttnProcessor2_0,
6079
+ SlicedAttnProcessor,
6080
+ SlicedAttnAddedKVProcessor,
6081
+ SanaLinearAttnProcessor2_0,
6082
+ PAGCFGSanaLinearAttnProcessor2_0,
6083
+ PAGIdentitySanaLinearAttnProcessor2_0,
6084
+ SanaMultiscaleLinearAttention,
6085
+ SanaMultiscaleAttnProcessor2_0,
6086
+ SanaMultiscaleAttentionProjection,
6087
+ IPAdapterAttnProcessor,
6088
+ IPAdapterAttnProcessor2_0,
6089
+ IPAdapterXFormersAttnProcessor,
6090
+ SD3IPAdapterJointAttnProcessor2_0,
4316
6091
  PAGIdentitySelfAttnProcessor2_0,
4317
- PAGCFGHunyuanAttnProcessor2_0,
4318
- PAGHunyuanAttnProcessor2_0,
6092
+ PAGCFGIdentitySelfAttnProcessor2_0,
6093
+ LoRAAttnProcessor,
6094
+ LoRAAttnProcessor2_0,
6095
+ LoRAXFormersAttnProcessor,
6096
+ LoRAAttnAddedKVProcessor,
4319
6097
  ]
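
The expanded `AttentionProcessor` union above enumerates every processor a model may register per attention module. In practice these classes are rarely instantiated by hand; attention-bearing models expose `attn_processors` for inspection and `set_attn_processor` for swapping them. A hedged usage sketch follows; the checkpoint id is only an example and the call downloads weights from the Hub.

```python
# Inspect and replace the attention processors on a loaded model.
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0

unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet"
)

# Which processor class does each attention layer currently use?
print({name: type(proc).__name__ for name, proc in unet.attn_processors.items()})

# Replace every processor with the PyTorch 2.0 scaled-dot-product implementation.
unet.set_attn_processor(AttnProcessor2_0())
```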