diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/models/attention_processor.py
@@ -20,8 +20,8 @@ import torch.nn.functional as F
 from torch import nn

 from ..image_processor import IPAdapterMaskProcessor
-from ..utils import deprecate, logging
-from ..utils.import_utils import is_torch_npu_available, is_xformers_available
+from ..utils import deprecate, is_torch_xla_available, logging
+from ..utils.import_utils import is_torch_npu_available, is_torch_xla_version, is_xformers_available
 from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph


@@ -36,6 +36,15 @@ if is_xformers_available():
 else:
     xformers = None

+if is_torch_xla_available():
+    # flash attention pallas kernel is introduced in the torch_xla 2.3 release.
+    if is_torch_xla_version(">", "2.2"):
+        from torch_xla.experimental.custom_kernel import flash_attention
+        from torch_xla.runtime import is_spmd
+        XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+

 @maybe_allow_in_graph
 class Attention(nn.Module):
@@ -120,13 +129,16 @@ class Attention(nn.Module):
         _from_deprecated_attn_block: bool = False,
         processor: Optional["AttnProcessor"] = None,
         out_dim: int = None,
+        out_context_dim: int = None,
         context_pre_only=None,
         pre_only=False,
+        elementwise_affine: bool = True,
+        is_causal: bool = False,
     ):
         super().__init__()

         # To prevent circular import.
-        from .normalization import FP32LayerNorm, RMSNorm
+        from .normalization import FP32LayerNorm, LpNorm, RMSNorm

         self.inner_dim = out_dim if out_dim is not None else dim_head * heads
         self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
@@ -141,8 +153,10 @@
         self.dropout = dropout
         self.fused_projections = False
         self.out_dim = out_dim if out_dim is not None else query_dim
+        self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
         self.context_pre_only = context_pre_only
         self.pre_only = pre_only
+        self.is_causal = is_causal

         # we make use of this private variable to know whether this class is loaded
         # with an deprecated state dict so that we can convert it on the fly
@@ -179,20 +193,27 @@
             self.norm_q = None
             self.norm_k = None
         elif qk_norm == "layer_norm":
-            self.norm_q = nn.LayerNorm(dim_head, eps=eps)
-            self.norm_k = nn.LayerNorm(dim_head, eps=eps)
+            self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
         elif qk_norm == "fp32_layer_norm":
             self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
             self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
         elif qk_norm == "layer_norm_across_heads":
-            # Lumina applys qk norm across all heads
+            # Lumina applies qk norm across all heads
             self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
             self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
         elif qk_norm == "rms_norm":
             self.norm_q = RMSNorm(dim_head, eps=eps)
             self.norm_k = RMSNorm(dim_head, eps=eps)
+        elif qk_norm == "rms_norm_across_heads":
+            # LTX applies qk norm across all heads
+            self.norm_q = RMSNorm(dim_head * heads, eps=eps)
+            self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps)
+        elif qk_norm == "l2":
+            self.norm_q = LpNorm(p=2, dim=-1, eps=eps)
+            self.norm_k = LpNorm(p=2, dim=-1, eps=eps)
         else:
-            raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
+            raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'")

         if cross_attention_norm is None:
             self.norm_cross = None
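
Note: the two new qk_norm branches above ("rms_norm_across_heads", used by LTX, and "l2") are selected the same way as the existing options, via the qk_norm argument of the Attention constructor. A minimal sketch, not part of the diff, with made-up sizes:

import torch
from diffusers.models.attention_processor import Attention

# "l2" routes q/k through the new LpNorm(p=2) module added in this release.
attn = Attention(query_dim=64, heads=4, dim_head=16, qk_norm="l2")
hidden_states = torch.randn(2, 128, 64)  # (batch, sequence, query_dim)
out = attn(hidden_states)                # plain self-attention forward pass
print(out.shape)                         # torch.Size([2, 128, 64])
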
@@ -233,14 +254,22 @@
             self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
             if self.context_pre_only is not None:
                 self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        else:
+            self.add_q_proj = None
+            self.add_k_proj = None
+            self.add_v_proj = None

         if not self.pre_only:
             self.to_out = nn.ModuleList([])
             self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
             self.to_out.append(nn.Dropout(dropout))
+        else:
+            self.to_out = None

         if self.context_pre_only is not None and not self.context_pre_only:
-            self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
+            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
+        else:
+            self.to_add_out = None

         if qk_norm is not None and added_kv_proj_dim is not None:
             if qk_norm == "fp32_layer_norm":
@@ -249,6 +278,10 @@
             elif qk_norm == "rms_norm":
                 self.norm_added_q = RMSNorm(dim_head, eps=eps)
                 self.norm_added_k = RMSNorm(dim_head, eps=eps)
+            else:
+                raise ValueError(
+                    f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
+                )
         else:
             self.norm_added_q = None
             self.norm_added_k = None
@@ -263,6 +296,33 @@
             )
         self.set_processor(processor)

+    def set_use_xla_flash_attention(
+        self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None
+    ) -> None:
+        r"""
+        Set whether to use xla flash attention from `torch_xla` or not.
+
+        Args:
+            use_xla_flash_attention (`bool`):
+                Whether to use pallas flash attention kernel from `torch_xla` or not.
+            partition_spec (`Tuple[]`, *optional*):
+                Specify the partition specification if using SPMD. Otherwise None.
+        """
+        if use_xla_flash_attention:
+            if not is_torch_xla_available:
+                raise "torch_xla is not available"
+            elif is_torch_xla_version("<", "2.3"):
+                raise "flash attention pallas kernel is supported from torch_xla version 2.3"
+            elif is_spmd() and is_torch_xla_version("<", "2.4"):
+                raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4"
+            else:
+                processor = XLAFlashAttnProcessor2_0(partition_spec)
+        else:
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )
+        self.set_processor(processor)
+
     def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
         r"""
         Set whether to use npu flash attention from `torch_npu` or not.
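
For orientation, the new set_use_xla_flash_attention method above is meant to be called on each Attention module when running under torch_xla on TPU. A minimal sketch, not part of the diff; the helper name is hypothetical, and it assumes torch_xla >= 2.3 is installed so the pallas flash-attention kernel referenced above is importable:

import torch.nn as nn
from diffusers.models.attention_processor import Attention

def enable_xla_flash_attention_on(model: nn.Module, partition_spec=None) -> None:
    # Walk the module tree and switch every Attention block over to the
    # XLA flash-attention processor path introduced in this release.
    for module in model.modules():
        if isinstance(module, Attention):
            module.set_use_xla_flash_attention(True, partition_spec=partition_spec)
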
@@ -306,6 +366,17 @@
                 XFormersAttnAddedKVProcessor,
             ),
         )
+        is_ip_adapter = hasattr(self, "processor") and isinstance(
+            self.processor,
+            (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
+        )
+        is_joint_processor = hasattr(self, "processor") and isinstance(
+            self.processor,
+            (
+                JointAttnProcessor2_0,
+                XFormersJointAttnProcessor,
+            ),
+        )

         if use_memory_efficient_attention_xformers:
             if is_added_kv_processor and is_custom_diffusion:
@@ -356,6 +427,21 @@
                     "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
                 )
                 processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
+            elif is_ip_adapter:
+                processor = IPAdapterXFormersAttnProcessor(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    num_tokens=self.processor.num_tokens,
+                    scale=self.processor.scale,
+                    attention_op=attention_op,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_ip"):
+                    processor.to(
+                        device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
+                    )
+            elif is_joint_processor:
+                processor = XFormersJointAttnProcessor(attention_op=attention_op)
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
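
With the two branches added above, enabling xFormers on a pipeline that already carries IP-Adapter (or SD3-style joint) processors converts them, weights included, instead of replacing them with the generic XFormersAttnProcessor. A hedged usage sketch, not part of the diff; the model and adapter ids are only examples and xformers must be installed:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
# After this call the attention modules hold IPAdapterXFormersAttnProcessor
# instances initialized from the existing processors' state dicts.
pipe.enable_xformers_memory_efficient_attention()
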
@@ -374,6 +460,18 @@
                 processor.load_state_dict(self.processor.state_dict())
                 if hasattr(self.processor, "to_k_custom_diffusion"):
                     processor.to(self.processor.to_k_custom_diffusion.weight.device)
+            elif is_ip_adapter:
+                processor = IPAdapterAttnProcessor2_0(
+                    hidden_size=self.processor.hidden_size,
+                    cross_attention_dim=self.processor.cross_attention_dim,
+                    num_tokens=self.processor.num_tokens,
+                    scale=self.processor.scale,
+                )
+                processor.load_state_dict(self.processor.state_dict())
+                if hasattr(self.processor, "to_k_ip"):
+                    processor.to(
+                        device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
+                    )
             else:
                 # set attention processor
                 # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
@@ -477,7 +575,7 @@
         # For standard processors that are defined here, `**cross_attention_kwargs` is empty

         attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
-        quiet_attn_parameters = {"ip_adapter_masks"}
+        quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
         unused_kwargs = [
             k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
         ]
@@ -692,7 +790,11 @@
                 self.to_kv.bias.copy_(concatenated_bias)

         # handle added projections for SD3 and others.
-        if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
+        if (
+            getattr(self, "add_q_proj", None) is not None
+            and getattr(self, "add_k_proj", None) is not None
+            and getattr(self, "add_v_proj", None) is not None
+        ):
             concatenated_weights = torch.cat(
                 [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
             )
@@ -712,6 +814,269 @@
         self.fused_projections = fuse


+class SanaMultiscaleAttentionProjection(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        num_attention_heads: int,
+        kernel_size: int,
+    ) -> None:
+        super().__init__()
+
+        channels = 3 * in_channels
+        self.proj_in = nn.Conv2d(
+            channels,
+            channels,
+            kernel_size,
+            padding=kernel_size // 2,
+            groups=channels,
+            bias=False,
+        )
+        self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.proj_in(hidden_states)
+        hidden_states = self.proj_out(hidden_states)
+        return hidden_states
+
+
+class SanaMultiscaleLinearAttention(nn.Module):
+    r"""Lightweight multi-scale linear attention"""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_attention_heads: Optional[int] = None,
+        attention_head_dim: int = 8,
+        mult: float = 1.0,
+        norm_type: str = "batch_norm",
+        kernel_sizes: Tuple[int, ...] = (5,),
+        eps: float = 1e-15,
+        residual_connection: bool = False,
+    ):
+        super().__init__()
+
+        # To prevent circular import
+        from .normalization import get_normalization
+
+        self.eps = eps
+        self.attention_head_dim = attention_head_dim
+        self.norm_type = norm_type
+        self.residual_connection = residual_connection
+
+        num_attention_heads = (
+            int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads
+        )
+        inner_dim = num_attention_heads * attention_head_dim
+
+        self.to_q = nn.Linear(in_channels, inner_dim, bias=False)
+        self.to_k = nn.Linear(in_channels, inner_dim, bias=False)
+        self.to_v = nn.Linear(in_channels, inner_dim, bias=False)
+
+        self.to_qkv_multiscale = nn.ModuleList()
+        for kernel_size in kernel_sizes:
+            self.to_qkv_multiscale.append(
+                SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
+            )
+
+        self.nonlinearity = nn.ReLU()
+        self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
+        self.norm_out = get_normalization(norm_type, num_features=out_channels)
+
+        self.processor = SanaMultiscaleAttnProcessor2_0()
+
+    def apply_linear_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)  # Adds padding
+        scores = torch.matmul(value, key.transpose(-1, -2))
+        hidden_states = torch.matmul(scores, query)
+
+        hidden_states = hidden_states.to(dtype=torch.float32)
+        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
+        return hidden_states
+
+    def apply_quadratic_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+        scores = torch.matmul(key.transpose(-1, -2), query)
+        scores = scores.to(dtype=torch.float32)
+        scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
+        hidden_states = torch.matmul(value, scores)
+        return hidden_states
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return self.processor(self, hidden_states)
+
+
+class MochiAttention(nn.Module):
+    def __init__(
+        self,
+        query_dim: int,
+        added_kv_proj_dim: int,
+        processor: "MochiAttnProcessor2_0",
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        added_proj_bias: bool = True,
+        out_dim: Optional[int] = None,
+        out_context_dim: Optional[int] = None,
+        out_bias: bool = True,
+        context_pre_only: bool = False,
+        eps: float = 1e-5,
+    ):
+        super().__init__()
+        from .normalization import MochiRMSNorm
+
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.out_dim = out_dim if out_dim is not None else query_dim
+        self.out_context_dim = out_context_dim if out_context_dim else query_dim
+        self.context_pre_only = context_pre_only
+
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+
+        self.norm_q = MochiRMSNorm(dim_head, eps, True)
+        self.norm_k = MochiRMSNorm(dim_head, eps, True)
+        self.norm_added_q = MochiRMSNorm(dim_head, eps, True)
+        self.norm_added_k = MochiRMSNorm(dim_head, eps, True)
+
+        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+        self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        if self.context_pre_only is not None:
+            self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(nn.Dropout(dropout))
+
+        if not self.context_pre_only:
+            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
+
+        self.processor = processor
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        return self.processor(
+            self,
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            **kwargs,
+        )
+
+
+class MochiAttnProcessor2_0:
+    """Attention processor used in Mochi."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: "MochiAttention",
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        encoder_query = attn.add_q_proj(encoder_hidden_states)
+        encoder_key = attn.add_k_proj(encoder_hidden_states)
+        encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
+        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
+        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
+
+        if attn.norm_added_q is not None:
+            encoder_query = attn.norm_added_q(encoder_query)
+        if attn.norm_added_k is not None:
+            encoder_key = attn.norm_added_k(encoder_key)
+
+        if image_rotary_emb is not None:
+
+            def apply_rotary_emb(x, freqs_cos, freqs_sin):
+                x_even = x[..., 0::2].float()
+                x_odd = x[..., 1::2].float()
+
+                cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
+                sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
+
+                return torch.stack([cos, sin], dim=-1).flatten(-2)
+
+            query = apply_rotary_emb(query, *image_rotary_emb)
+            key = apply_rotary_emb(key, *image_rotary_emb)
+
+        query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
+        encoder_query, encoder_key, encoder_value = (
+            encoder_query.transpose(1, 2),
+            encoder_key.transpose(1, 2),
+            encoder_value.transpose(1, 2),
+        )
+
+        sequence_length = query.size(2)
+        encoder_sequence_length = encoder_query.size(2)
+        total_length = sequence_length + encoder_sequence_length
+
+        batch_size, heads, _, dim = query.shape
+        attn_outputs = []
+        for idx in range(batch_size):
+            mask = attention_mask[idx][None, :]
+            valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
+
+            valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :]
+            valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :]
+            valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :]
+
+            valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2)
+            valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2)
+            valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2)
+
+            attn_output = F.scaled_dot_product_attention(
+                valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False
+            )
+            valid_sequence_length = attn_output.size(2)
+            attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
+            attn_outputs.append(attn_output)
+
+        hidden_states = torch.cat(attn_outputs, dim=0)
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+
+        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
+            (sequence_length, encoder_sequence_length), dim=1
+        )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if hasattr(attn, "to_add_out"):
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        return hidden_states, encoder_hidden_states
+
+
 class AttnProcessor:
     r"""
     Default processor for performing attention-related computations.
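
Unlike Attention, the MochiAttention module added above takes its processor as a required constructor argument, and MochiAttnProcessor2_0 drops padded prompt tokens per sample (via attention_mask) before running scaled-dot-product attention. A minimal sketch, not part of the diff, with made-up shapes:

import torch
from diffusers.models.attention_processor import MochiAttention, MochiAttnProcessor2_0

attn = MochiAttention(
    query_dim=64,          # video-token channels (illustrative)
    added_kv_proj_dim=32,  # text-token channels (illustrative)
    processor=MochiAttnProcessor2_0(),
    heads=4,
    dim_head=16,
)
hidden_states = torch.randn(2, 10, 64)               # (batch, video tokens, query_dim)
encoder_hidden_states = torch.randn(2, 6, 32)        # (batch, text tokens, added_kv_proj_dim)
attention_mask = torch.ones(2, 6, dtype=torch.bool)  # marks valid prompt tokens per sample
out, encoder_out = attn(
    hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
)
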
@@ -1049,61 +1414,72 @@ class JointAttnProcessor2_0:
     ) -> torch.FloatTensor:
         residual = hidden_states

-        input_ndim = hidden_states.ndim
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-        context_input_ndim = encoder_hidden_states.ndim
-        if context_input_ndim == 4:
-            batch_size, channel, height, width = encoder_hidden_states.shape
-            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size = encoder_hidden_states.shape[0]
+        batch_size = hidden_states.shape[0]

         # `sample` projections.
         query = attn.to_q(hidden_states)
         key = attn.to_k(hidden_states)
         value = attn.to_v(hidden_states)

-        # `context` projections.
-        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-        # attention
-        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
-        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
-        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
-
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
+
         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
+            key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
+            value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
+
         hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

-        # Split the attention outputs.
-        hidden_states, encoder_hidden_states = (
-            hidden_states[:, : residual.shape[1]],
-            hidden_states[:, residual.shape[1] :],
-        )
+        if encoder_hidden_states is not None:
+            # Split the attention outputs.
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, : residual.shape[1]],
+                hidden_states[:, residual.shape[1] :],
+            )
+            if not attn.context_pre_only:
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
-        if not attn.context_pre_only:
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-        if context_input_ndim == 4:
-            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

-        return hidden_states, encoder_hidden_states
+        if encoder_hidden_states is not None:
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states


 class PAGJointAttnProcessor2_0:
@@ -1120,6 +1496,7 @@ class PAGJointAttnProcessor2_0:
         attn: Attention,
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
     ) -> torch.FloatTensor:
         residual = hidden_states

@@ -1505,34 +1882,213 @@ class FusedJointAttnProcessor2_0:
         return hidden_states, encoder_hidden_states


-class AuraFlowAttnProcessor2_0:
-    """Attention processor used typically in processing Aura Flow."""
+class XFormersJointAttnProcessor:
+    r"""
+    Processor for implementing memory efficient attention using xFormers.

-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
-            raise ImportError(
-                "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
-            )
+    Args:
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+    """
+
+    def __init__(self, attention_op: Optional[Callable] = None):
+        self.attention_op = attention_op

     def __call__(
         self,
         attn: Attention,
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
         *args,
         **kwargs,
     ) -> torch.FloatTensor:
-        batch_size = hidden_states.shape[0]
+        residual = hidden_states

         # `sample` projections.
         query = attn.to_q(hidden_states)
         key = attn.to_k(hidden_states)
         value = attn.to_v(hidden_states)

-        # `context` projections.
-        if encoder_hidden_states is not None:
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = attn.head_to_batch_dim(encoder_hidden_states_query_proj).contiguous()
+            encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj).contiguous()
+            encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj).contiguous()
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+            key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+            value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if encoder_hidden_states is not None:
+            # Split the attention outputs.
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, : residual.shape[1]],
+                hidden_states[:, residual.shape[1] :],
+            )
+            if not attn.context_pre_only:
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if encoder_hidden_states is not None:
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
+class AllegroAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the Allegro model. It applies a normalization layer and rotary embedding on the query and key vector.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "AllegroAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None and not attn.is_cross_attention:
+            from .embeddings import apply_rotary_emb_allegro
+
+            query = apply_rotary_emb_allegro(query, image_rotary_emb[0], image_rotary_emb[1])
+            key = apply_rotary_emb_allegro(key, image_rotary_emb[0], image_rotary_emb[1])
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class AuraFlowAttnProcessor2_0:
+    """Attention processor used typically in processing Aura Flow."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
+            raise ImportError(
+                "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        batch_size = hidden_states.shape[0]
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
             encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

             # Reshape.
@@ -1695,52 +2251,231 @@ class FusedAuraFlowAttnProcessor2_0:
1695
2251
  return hidden_states
1696
2252
 
1697
2253
 
1698
- # YiYi to-do: refactor rope related functions/classes
1699
- def apply_rope(xq, xk, freqs_cis):
1700
- xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
1701
- xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
1702
- xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
1703
- xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
1704
- return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
1705
-
1706
-
1707
- class FluxSingleAttnProcessor2_0:
1708
- r"""
1709
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
1710
- """
2254
+ class FluxAttnProcessor2_0:
2255
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
1711
2256
 
1712
2257
  def __init__(self):
1713
2258
  if not hasattr(F, "scaled_dot_product_attention"):
1714
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
2259
+ raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1715
2260
 
1716
2261
  def __call__(
1717
2262
  self,
1718
2263
  attn: Attention,
1719
- hidden_states: torch.Tensor,
1720
- encoder_hidden_states: Optional[torch.Tensor] = None,
2264
+ hidden_states: torch.FloatTensor,
2265
+ encoder_hidden_states: torch.FloatTensor = None,
1721
2266
  attention_mask: Optional[torch.FloatTensor] = None,
1722
2267
  image_rotary_emb: Optional[torch.Tensor] = None,
1723
- ) -> torch.Tensor:
1724
- input_ndim = hidden_states.ndim
2268
+ ) -> torch.FloatTensor:
2269
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1725
2270
 
1726
- if input_ndim == 4:
1727
- batch_size, channel, height, width = hidden_states.shape
1728
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
2271
+ # `sample` projections.
2272
+ query = attn.to_q(hidden_states)
2273
+ key = attn.to_k(hidden_states)
2274
+ value = attn.to_v(hidden_states)
2275
+
2276
+ inner_dim = key.shape[-1]
2277
+ head_dim = inner_dim // attn.heads
2278
+
2279
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2280
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2281
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1729
2282
 
2283
+ if attn.norm_q is not None:
2284
+ query = attn.norm_q(query)
2285
+ if attn.norm_k is not None:
2286
+ key = attn.norm_k(key)
2287
+
2288
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
2289
+ if encoder_hidden_states is not None:
2290
+ # `context` projections.
2291
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
2292
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
2293
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
2294
+
2295
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
2296
+ batch_size, -1, attn.heads, head_dim
2297
+ ).transpose(1, 2)
2298
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
2299
+ batch_size, -1, attn.heads, head_dim
2300
+ ).transpose(1, 2)
2301
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
2302
+ batch_size, -1, attn.heads, head_dim
2303
+ ).transpose(1, 2)
2304
+
2305
+ if attn.norm_added_q is not None:
2306
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
2307
+ if attn.norm_added_k is not None:
2308
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
2309
+
2310
+ # attention
2311
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
2312
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
2313
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
2314
+
2315
+ if image_rotary_emb is not None:
2316
+ from .embeddings import apply_rotary_emb
2317
+
2318
+ query = apply_rotary_emb(query, image_rotary_emb)
2319
+ key = apply_rotary_emb(key, image_rotary_emb)
2320
+
2321
+ hidden_states = F.scaled_dot_product_attention(
2322
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
2323
+ )
2324
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2325
+ hidden_states = hidden_states.to(query.dtype)
2326
+
2327
+ if encoder_hidden_states is not None:
2328
+ encoder_hidden_states, hidden_states = (
2329
+ hidden_states[:, : encoder_hidden_states.shape[1]],
2330
+ hidden_states[:, encoder_hidden_states.shape[1] :],
2331
+ )
2332
+
2333
+ # linear proj
2334
+ hidden_states = attn.to_out[0](hidden_states)
2335
+ # dropout
2336
+ hidden_states = attn.to_out[1](hidden_states)
2337
+
2338
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
2339
+
2340
+ return hidden_states, encoder_hidden_states
2341
+ else:
2342
+ return hidden_states
2343
+
2344
+
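FluxAttnProcessor2_0 above becomes the default processor for Flux attention, now routed through the shared `apply_rotary_emb` helper instead of the removed local `apply_rope`. A minimal usage sketch, assuming a Flux checkpoint is available locally (the model id and dtype are illustrative, not part of this diff):

    import torch
    from diffusers import FluxTransformer2DModel
    from diffusers.models.attention_processor import FluxAttnProcessor2_0

    # Load only the transformer; "black-forest-labs/FLUX.1-dev" is an illustrative checkpoint.
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
    )
    # set_attn_processor swaps the processor on every attention module of the model.
    transformer.set_attn_processor(FluxAttnProcessor2_0())
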
2345
+ class FluxAttnProcessor2_0_NPU:
2346
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
2347
+
2348
+ def __init__(self):
2349
+ if not hasattr(F, "scaled_dot_product_attention"):
2350
+ raise ImportError(
2351
+ "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
2352
+ )
2353
+
2354
+ def __call__(
2355
+ self,
2356
+ attn: Attention,
2357
+ hidden_states: torch.FloatTensor,
2358
+ encoder_hidden_states: torch.FloatTensor = None,
2359
+ attention_mask: Optional[torch.FloatTensor] = None,
2360
+ image_rotary_emb: Optional[torch.Tensor] = None,
2361
+ ) -> torch.FloatTensor:
1730
2362
  batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1731
2363
 
2364
+ # `sample` projections.
1732
2365
  query = attn.to_q(hidden_states)
1733
- if encoder_hidden_states is None:
1734
- encoder_hidden_states = hidden_states
1735
-
1736
- key = attn.to_k(encoder_hidden_states)
1737
- value = attn.to_v(encoder_hidden_states)
2366
+ key = attn.to_k(hidden_states)
2367
+ value = attn.to_v(hidden_states)
1738
2368
 
1739
2369
  inner_dim = key.shape[-1]
1740
2370
  head_dim = inner_dim // attn.heads
1741
2371
 
1742
2372
  query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2373
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2374
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2375
+
2376
+ if attn.norm_q is not None:
2377
+ query = attn.norm_q(query)
2378
+ if attn.norm_k is not None:
2379
+ key = attn.norm_k(key)
2380
+
2381
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
2382
+ if encoder_hidden_states is not None:
2383
+ # `context` projections.
2384
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
2385
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
2386
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
2387
+
2388
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
2389
+ batch_size, -1, attn.heads, head_dim
2390
+ ).transpose(1, 2)
2391
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
2392
+ batch_size, -1, attn.heads, head_dim
2393
+ ).transpose(1, 2)
2394
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
2395
+ batch_size, -1, attn.heads, head_dim
2396
+ ).transpose(1, 2)
2397
+
2398
+ if attn.norm_added_q is not None:
2399
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
2400
+ if attn.norm_added_k is not None:
2401
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
2402
+
2403
+ # attention
2404
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
2405
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
2406
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
2407
+
2408
+ if image_rotary_emb is not None:
2409
+ from .embeddings import apply_rotary_emb
2410
+
2411
+ query = apply_rotary_emb(query, image_rotary_emb)
2412
+ key = apply_rotary_emb(key, image_rotary_emb)
2413
+
2414
+ if query.dtype in (torch.float16, torch.bfloat16):
2415
+ hidden_states = torch_npu.npu_fusion_attention(
2416
+ query,
2417
+ key,
2418
+ value,
2419
+ attn.heads,
2420
+ input_layout="BNSD",
2421
+ pse=None,
2422
+ scale=1.0 / math.sqrt(query.shape[-1]),
2423
+ pre_tockens=65536,
2424
+ next_tockens=65536,
2425
+ keep_prob=1.0,
2426
+ sync=False,
2427
+ inner_precise=0,
2428
+ )[0]
2429
+ else:
2430
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
2431
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2432
+ hidden_states = hidden_states.to(query.dtype)
2433
+
2434
+ if encoder_hidden_states is not None:
2435
+ encoder_hidden_states, hidden_states = (
2436
+ hidden_states[:, : encoder_hidden_states.shape[1]],
2437
+ hidden_states[:, encoder_hidden_states.shape[1] :],
2438
+ )
1743
2439
 
2440
+ # linear proj
2441
+ hidden_states = attn.to_out[0](hidden_states)
2442
+ # dropout
2443
+ hidden_states = attn.to_out[1](hidden_states)
2444
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
2445
+
2446
+ return hidden_states, encoder_hidden_states
2447
+ else:
2448
+ return hidden_states
2449
+
2450
+
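FluxAttnProcessor2_0_NPU mirrors the processor above but dispatches fp16/bf16 inputs to `torch_npu.npu_fusion_attention` and falls back to SDPA for other dtypes. A hedged selection sketch, reusing `transformer` from the previous example (the `torch.npu.is_available()` check assumes torch_npu has patched the `torch.npu` namespace):

    import torch
    from diffusers.models.attention_processor import FluxAttnProcessor2_0, FluxAttnProcessor2_0_NPU

    try:
        import torch_npu  # noqa: F401  (Ascend NPU build of PyTorch)
        npu_available = torch.npu.is_available()  # assumption: torch_npu exposes torch.npu
    except ImportError:
        npu_available = False

    processor = FluxAttnProcessor2_0_NPU() if npu_available else FluxAttnProcessor2_0()
    transformer.set_attn_processor(processor)
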
2451
+ class FusedFluxAttnProcessor2_0:
2452
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
2453
+
2454
+ def __init__(self):
2455
+ if not hasattr(F, "scaled_dot_product_attention"):
2456
+ raise ImportError(
2457
+ "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
2458
+ )
2459
+
2460
+ def __call__(
2461
+ self,
2462
+ attn: Attention,
2463
+ hidden_states: torch.FloatTensor,
2464
+ encoder_hidden_states: torch.FloatTensor = None,
2465
+ attention_mask: Optional[torch.FloatTensor] = None,
2466
+ image_rotary_emb: Optional[torch.Tensor] = None,
2467
+ ) -> torch.FloatTensor:
2468
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
2469
+
2470
+ # `sample` projections.
2471
+ qkv = attn.to_qkv(hidden_states)
2472
+ split_size = qkv.shape[-1] // 3
2473
+ query, key, value = torch.split(qkv, split_size, dim=-1)
2474
+
2475
+ inner_dim = key.shape[-1]
2476
+ head_dim = inner_dim // attn.heads
2477
+
2478
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1744
2479
  key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1745
2480
  value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1746
2481
 
@@ -1749,33 +2484,212 @@ class FluxSingleAttnProcessor2_0:
1749
2484
  if attn.norm_k is not None:
1750
2485
  key = attn.norm_k(key)
1751
2486
 
1752
- # Apply RoPE if needed
2487
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
2488
+ # `context` projections.
2489
+ if encoder_hidden_states is not None:
2490
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
2491
+ split_size = encoder_qkv.shape[-1] // 3
2492
+ (
2493
+ encoder_hidden_states_query_proj,
2494
+ encoder_hidden_states_key_proj,
2495
+ encoder_hidden_states_value_proj,
2496
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
2497
+
2498
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
2499
+ batch_size, -1, attn.heads, head_dim
2500
+ ).transpose(1, 2)
2501
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
2502
+ batch_size, -1, attn.heads, head_dim
2503
+ ).transpose(1, 2)
2504
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
2505
+ batch_size, -1, attn.heads, head_dim
2506
+ ).transpose(1, 2)
2507
+
2508
+ if attn.norm_added_q is not None:
2509
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
2510
+ if attn.norm_added_k is not None:
2511
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
2512
+
2513
+ # attention
2514
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
2515
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
2516
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
2517
+
1753
2518
  if image_rotary_emb is not None:
1754
- # YiYi to-do: update uising apply_rotary_emb
1755
- # from ..embeddings import apply_rotary_emb
1756
- # query = apply_rotary_emb(query, image_rotary_emb)
1757
- # key = apply_rotary_emb(key, image_rotary_emb)
1758
- query, key = apply_rope(query, key, image_rotary_emb)
2519
+ from .embeddings import apply_rotary_emb
2520
+
2521
+ query = apply_rotary_emb(query, image_rotary_emb)
2522
+ key = apply_rotary_emb(key, image_rotary_emb)
1759
2523
 
1760
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
1761
- # TODO: add support for attn.scale when we move to Torch 2.1
1762
2524
  hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
2525
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2526
+ hidden_states = hidden_states.to(query.dtype)
2527
+
2528
+ if encoder_hidden_states is not None:
2529
+ encoder_hidden_states, hidden_states = (
2530
+ hidden_states[:, : encoder_hidden_states.shape[1]],
2531
+ hidden_states[:, encoder_hidden_states.shape[1] :],
2532
+ )
2533
+
2534
+ # linear proj
2535
+ hidden_states = attn.to_out[0](hidden_states)
2536
+ # dropout
2537
+ hidden_states = attn.to_out[1](hidden_states)
2538
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
2539
+
2540
+ return hidden_states, encoder_hidden_states
2541
+ else:
2542
+ return hidden_states
2543
+
2544
+
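FusedFluxAttnProcessor2_0 expects the separate q/k/v (and added q/k/v) linears to have been merged into single `to_qkv` / `to_added_qkv` projections, which is why it splits each projection output into three chunks. A hedged sketch of how it is normally reached, assuming the Flux transformer exposes `fuse_qkv_projections()` as other diffusers transformers do:

    # Fuse to_q/to_k/to_v into to_qkv and switch to the fused processor internally.
    transformer.fuse_qkv_projections()

    # ... run inference with the fused projections ...

    # Revert to the separate projections (and the non-fused processor) if needed.
    transformer.unfuse_qkv_projections()
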
2545
+ class FusedFluxAttnProcessor2_0_NPU:
2546
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
2547
+
2548
+ def __init__(self):
2549
+ if not hasattr(F, "scaled_dot_product_attention"):
2550
+ raise ImportError(
2551
+ "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU"
2552
+ )
2553
+
2554
+ def __call__(
2555
+ self,
2556
+ attn: Attention,
2557
+ hidden_states: torch.FloatTensor,
2558
+ encoder_hidden_states: torch.FloatTensor = None,
2559
+ attention_mask: Optional[torch.FloatTensor] = None,
2560
+ image_rotary_emb: Optional[torch.Tensor] = None,
2561
+ ) -> torch.FloatTensor:
2562
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
2563
+
2564
+ # `sample` projections.
2565
+ qkv = attn.to_qkv(hidden_states)
2566
+ split_size = qkv.shape[-1] // 3
2567
+ query, key, value = torch.split(qkv, split_size, dim=-1)
2568
+
2569
+ inner_dim = key.shape[-1]
2570
+ head_dim = inner_dim // attn.heads
2571
+
2572
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2573
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2574
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2575
+
2576
+ if attn.norm_q is not None:
2577
+ query = attn.norm_q(query)
2578
+ if attn.norm_k is not None:
2579
+ key = attn.norm_k(key)
2580
+
2581
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
2582
+ # `context` projections.
2583
+ if encoder_hidden_states is not None:
2584
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
2585
+ split_size = encoder_qkv.shape[-1] // 3
2586
+ (
2587
+ encoder_hidden_states_query_proj,
2588
+ encoder_hidden_states_key_proj,
2589
+ encoder_hidden_states_value_proj,
2590
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
2591
+
2592
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
2593
+ batch_size, -1, attn.heads, head_dim
2594
+ ).transpose(1, 2)
2595
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
2596
+ batch_size, -1, attn.heads, head_dim
2597
+ ).transpose(1, 2)
2598
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
2599
+ batch_size, -1, attn.heads, head_dim
2600
+ ).transpose(1, 2)
2601
+
2602
+ if attn.norm_added_q is not None:
2603
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
2604
+ if attn.norm_added_k is not None:
2605
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
2606
+
2607
+ # attention
2608
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
2609
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
2610
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
2611
+
2612
+ if image_rotary_emb is not None:
2613
+ from .embeddings import apply_rotary_emb
2614
+
2615
+ query = apply_rotary_emb(query, image_rotary_emb)
2616
+ key = apply_rotary_emb(key, image_rotary_emb)
2617
+
2618
+ if query.dtype in (torch.float16, torch.bfloat16):
2619
+ hidden_states = torch_npu.npu_fusion_attention(
2620
+ query,
2621
+ key,
2622
+ value,
2623
+ attn.heads,
2624
+ input_layout="BNSD",
2625
+ pse=None,
2626
+ scale=1.0 / math.sqrt(query.shape[-1]),
2627
+ pre_tockens=65536,
2628
+ next_tockens=65536,
2629
+ keep_prob=1.0,
2630
+ sync=False,
2631
+ inner_precise=0,
2632
+ )[0]
2633
+ else:
2634
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
1763
2635
 
1764
2636
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1765
2637
  hidden_states = hidden_states.to(query.dtype)
1766
2638
 
1767
- if input_ndim == 4:
1768
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
2639
+ if encoder_hidden_states is not None:
2640
+ encoder_hidden_states, hidden_states = (
2641
+ hidden_states[:, : encoder_hidden_states.shape[1]],
2642
+ hidden_states[:, encoder_hidden_states.shape[1] :],
2643
+ )
1769
2644
 
1770
- return hidden_states
2645
+ # linear proj
2646
+ hidden_states = attn.to_out[0](hidden_states)
2647
+ # dropout
2648
+ hidden_states = attn.to_out[1](hidden_states)
2649
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
2650
+
2651
+ return hidden_states, encoder_hidden_states
2652
+ else:
2653
+ return hidden_states
2654
+
2655
+
2656
+ class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
2657
+ """Flux Attention processor for IP-Adapter."""
2658
+
2659
+ def __init__(
2660
+ self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
2661
+ ):
2662
+ super().__init__()
2663
+
2664
+ if not hasattr(F, "scaled_dot_product_attention"):
2665
+ raise ImportError(
2666
+ f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
2667
+ )
2668
+
2669
+ self.hidden_size = hidden_size
2670
+ self.cross_attention_dim = cross_attention_dim
1771
2671
 
2672
+ if not isinstance(num_tokens, (tuple, list)):
2673
+ num_tokens = [num_tokens]
1772
2674
 
1773
- class FluxAttnProcessor2_0:
1774
- """Attention processor used typically in processing the SD3-like self-attention projections."""
2675
+ if not isinstance(scale, list):
2676
+ scale = [scale] * len(num_tokens)
2677
+ if len(scale) != len(num_tokens):
2678
+ raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
2679
+ self.scale = scale
1775
2680
 
1776
- def __init__(self):
1777
- if not hasattr(F, "scaled_dot_product_attention"):
1778
- raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
2681
+ self.to_k_ip = nn.ModuleList(
2682
+ [
2683
+ nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
2684
+ for _ in range(len(num_tokens))
2685
+ ]
2686
+ )
2687
+ self.to_v_ip = nn.ModuleList(
2688
+ [
2689
+ nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
2690
+ for _ in range(len(num_tokens))
2691
+ ]
2692
+ )
1779
2693
 
1780
2694
  def __call__(
1781
2695
  self,
@@ -1784,88 +2698,102 @@ class FluxAttnProcessor2_0:
1784
2698
  encoder_hidden_states: torch.FloatTensor = None,
1785
2699
  attention_mask: Optional[torch.FloatTensor] = None,
1786
2700
  image_rotary_emb: Optional[torch.Tensor] = None,
2701
+ ip_hidden_states: Optional[List[torch.Tensor]] = None,
2702
+ ip_adapter_masks: Optional[torch.Tensor] = None,
1787
2703
  ) -> torch.FloatTensor:
1788
- input_ndim = hidden_states.ndim
1789
- if input_ndim == 4:
1790
- batch_size, channel, height, width = hidden_states.shape
1791
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1792
- context_input_ndim = encoder_hidden_states.ndim
1793
- if context_input_ndim == 4:
1794
- batch_size, channel, height, width = encoder_hidden_states.shape
1795
- encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1796
-
1797
- batch_size = encoder_hidden_states.shape[0]
2704
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1798
2705
 
1799
2706
  # `sample` projections.
1800
- query = attn.to_q(hidden_states)
2707
+ hidden_states_query_proj = attn.to_q(hidden_states)
1801
2708
  key = attn.to_k(hidden_states)
1802
2709
  value = attn.to_v(hidden_states)
1803
2710
 
1804
2711
  inner_dim = key.shape[-1]
1805
2712
  head_dim = inner_dim // attn.heads
1806
2713
 
1807
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2714
+ hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1808
2715
  key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1809
2716
  value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1810
2717
 
1811
2718
  if attn.norm_q is not None:
1812
- query = attn.norm_q(query)
2719
+ hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
1813
2720
  if attn.norm_k is not None:
1814
2721
  key = attn.norm_k(key)
1815
2722
 
1816
- # `context` projections.
1817
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
1818
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1819
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
2723
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
2724
+ if encoder_hidden_states is not None:
2725
+ # `context` projections.
2726
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
2727
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
2728
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1820
2729
 
1821
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
1822
- batch_size, -1, attn.heads, head_dim
1823
- ).transpose(1, 2)
1824
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
1825
- batch_size, -1, attn.heads, head_dim
1826
- ).transpose(1, 2)
1827
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
1828
- batch_size, -1, attn.heads, head_dim
1829
- ).transpose(1, 2)
2730
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
2731
+ batch_size, -1, attn.heads, head_dim
2732
+ ).transpose(1, 2)
2733
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
2734
+ batch_size, -1, attn.heads, head_dim
2735
+ ).transpose(1, 2)
2736
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
2737
+ batch_size, -1, attn.heads, head_dim
2738
+ ).transpose(1, 2)
1830
2739
 
1831
- if attn.norm_added_q is not None:
1832
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
1833
- if attn.norm_added_k is not None:
1834
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
2740
+ if attn.norm_added_q is not None:
2741
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
2742
+ if attn.norm_added_k is not None:
2743
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
1835
2744
 
1836
- # attention
1837
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
1838
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
1839
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
2745
+ # attention
2746
+ query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
2747
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
2748
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
1840
2749
 
1841
2750
  if image_rotary_emb is not None:
1842
- # YiYi to-do: update uising apply_rotary_emb
1843
- # from ..embeddings import apply_rotary_emb
1844
- # query = apply_rotary_emb(query, image_rotary_emb)
1845
- # key = apply_rotary_emb(key, image_rotary_emb)
1846
- query, key = apply_rope(query, key, image_rotary_emb)
2751
+ from .embeddings import apply_rotary_emb
2752
+
2753
+ query = apply_rotary_emb(query, image_rotary_emb)
2754
+ key = apply_rotary_emb(key, image_rotary_emb)
1847
2755
 
1848
2756
  hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
1849
2757
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1850
2758
  hidden_states = hidden_states.to(query.dtype)
1851
2759
 
1852
- encoder_hidden_states, hidden_states = (
1853
- hidden_states[:, : encoder_hidden_states.shape[1]],
1854
- hidden_states[:, encoder_hidden_states.shape[1] :],
1855
- )
2760
+ if encoder_hidden_states is not None:
2761
+ encoder_hidden_states, hidden_states = (
2762
+ hidden_states[:, : encoder_hidden_states.shape[1]],
2763
+ hidden_states[:, encoder_hidden_states.shape[1] :],
2764
+ )
1856
2765
 
1857
- # linear proj
1858
- hidden_states = attn.to_out[0](hidden_states)
1859
- # dropout
1860
- hidden_states = attn.to_out[1](hidden_states)
1861
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
2766
+ # linear proj
2767
+ hidden_states = attn.to_out[0](hidden_states)
2768
+ # dropout
2769
+ hidden_states = attn.to_out[1](hidden_states)
2770
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
1862
2771
 
1863
- if input_ndim == 4:
1864
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1865
- if context_input_ndim == 4:
1866
- encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
2772
+ # IP-adapter
2773
+ ip_query = hidden_states_query_proj
2774
+ ip_attn_output = None
2775
+ # for ip-adapter
2776
+ # TODO: support for multiple adapters
2777
+ for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
2778
+ ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
2779
+ ):
2780
+ ip_key = to_k_ip(current_ip_hidden_states)
2781
+ ip_value = to_v_ip(current_ip_hidden_states)
2782
+
2783
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2784
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2785
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
2786
+ # TODO: add support for attn.scale when we move to Torch 2.1
2787
+ ip_attn_output = F.scaled_dot_product_attention(
2788
+ ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
2789
+ )
2790
+ ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
2791
+ ip_attn_output = scale * ip_attn_output
2792
+ ip_attn_output = ip_attn_output.to(ip_query.dtype)
1867
2793
 
1868
- return hidden_states, encoder_hidden_states
2794
+ return hidden_states, encoder_hidden_states, ip_attn_output
2795
+ else:
2796
+ return hidden_states
1869
2797
 
1870
2798
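FluxIPAdapterJointAttnProcessor2_0 adds per-adapter `to_k_ip` / `to_v_ip` projections and returns the IP branch output alongside the regular hidden states so the transformer block can blend it with its own scale. A hedged construction sketch; the sizes below are illustrative assumptions, and in practice the processor is installed by the IP-Adapter loader (`load_ip_adapter`) rather than built by hand:

    import torch
    from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0

    processor = FluxIPAdapterJointAttnProcessor2_0(
        hidden_size=3072,            # illustrative joint-stream width
        cross_attention_dim=4096,    # illustrative image-embedding width
        num_tokens=(128,),           # context length of the image-prompt features
        scale=[0.6],                 # one scale per adapter
        dtype=torch.bfloat16,
    )
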
 
1871
2799
  class CogVideoXAttnProcessor2_0:
@@ -2260,7 +3188,217 @@ class AttnProcessorNPU:
2260
3188
  inner_precise=0,
2261
3189
  )[0]
2262
3190
  else:
2263
- # TODO: add support for attn.scale when we move to Torch 2.1
3191
+ # TODO: add support for attn.scale when we move to Torch 2.1
3192
+ hidden_states = F.scaled_dot_product_attention(
3193
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
3194
+ )
3195
+
3196
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
3197
+ hidden_states = hidden_states.to(query.dtype)
3198
+
3199
+ # linear proj
3200
+ hidden_states = attn.to_out[0](hidden_states)
3201
+ # dropout
3202
+ hidden_states = attn.to_out[1](hidden_states)
3203
+
3204
+ if input_ndim == 4:
3205
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
3206
+
3207
+ if attn.residual_connection:
3208
+ hidden_states = hidden_states + residual
3209
+
3210
+ hidden_states = hidden_states / attn.rescale_output_factor
3211
+
3212
+ return hidden_states
3213
+
3214
+
3215
+ class AttnProcessor2_0:
3216
+ r"""
3217
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
3218
+ """
3219
+
3220
+ def __init__(self):
3221
+ if not hasattr(F, "scaled_dot_product_attention"):
3222
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
3223
+
3224
+ def __call__(
3225
+ self,
3226
+ attn: Attention,
3227
+ hidden_states: torch.Tensor,
3228
+ encoder_hidden_states: Optional[torch.Tensor] = None,
3229
+ attention_mask: Optional[torch.Tensor] = None,
3230
+ temb: Optional[torch.Tensor] = None,
3231
+ *args,
3232
+ **kwargs,
3233
+ ) -> torch.Tensor:
3234
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
3235
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
3236
+ deprecate("scale", "1.0.0", deprecation_message)
3237
+
3238
+ residual = hidden_states
3239
+ if attn.spatial_norm is not None:
3240
+ hidden_states = attn.spatial_norm(hidden_states, temb)
3241
+
3242
+ input_ndim = hidden_states.ndim
3243
+
3244
+ if input_ndim == 4:
3245
+ batch_size, channel, height, width = hidden_states.shape
3246
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
3247
+
3248
+ batch_size, sequence_length, _ = (
3249
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
3250
+ )
3251
+
3252
+ if attention_mask is not None:
3253
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
3254
+ # scaled_dot_product_attention expects attention_mask shape to be
3255
+ # (batch, heads, source_length, target_length)
3256
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
3257
+
3258
+ if attn.group_norm is not None:
3259
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
3260
+
3261
+ query = attn.to_q(hidden_states)
3262
+
3263
+ if encoder_hidden_states is None:
3264
+ encoder_hidden_states = hidden_states
3265
+ elif attn.norm_cross:
3266
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
3267
+
3268
+ key = attn.to_k(encoder_hidden_states)
3269
+ value = attn.to_v(encoder_hidden_states)
3270
+
3271
+ inner_dim = key.shape[-1]
3272
+ head_dim = inner_dim // attn.heads
3273
+
3274
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3275
+
3276
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3277
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3278
+
3279
+ if attn.norm_q is not None:
3280
+ query = attn.norm_q(query)
3281
+ if attn.norm_k is not None:
3282
+ key = attn.norm_k(key)
3283
+
3284
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
3285
+ # TODO: add support for attn.scale when we move to Torch 2.1
3286
+ hidden_states = F.scaled_dot_product_attention(
3287
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
3288
+ )
3289
+
3290
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
3291
+ hidden_states = hidden_states.to(query.dtype)
3292
+
3293
+ # linear proj
3294
+ hidden_states = attn.to_out[0](hidden_states)
3295
+ # dropout
3296
+ hidden_states = attn.to_out[1](hidden_states)
3297
+
3298
+ if input_ndim == 4:
3299
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
3300
+
3301
+ if attn.residual_connection:
3302
+ hidden_states = hidden_states + residual
3303
+
3304
+ hidden_states = hidden_states / attn.rescale_output_factor
3305
+
3306
+ return hidden_states
3307
+
3308
+
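The `AttnProcessor2_0` block above appears unchanged in behaviour and is re-added here as part of reorganising the module. For reference, a minimal sketch of setting it explicitly on a UNet (most models already default to it on PyTorch 2.x; the checkpoint id is illustrative):

    from diffusers import UNet2DConditionModel
    from diffusers.models.attention_processor import AttnProcessor2_0

    unet = UNet2DConditionModel.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet"
    )
    unet.set_attn_processor(AttnProcessor2_0())
    # `attn_processors` maps every attention layer name to its processor instance.
    print(next(iter(unet.attn_processors.values())))
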
3309
+ class XLAFlashAttnProcessor2_0:
3310
+ r"""
3311
+ Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
3312
+ """
3313
+
3314
+ def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
3315
+ if not hasattr(F, "scaled_dot_product_attention"):
3316
+ raise ImportError(
3317
+ "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
3318
+ )
3319
+ if is_torch_xla_version("<", "2.3"):
3320
+ raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
3321
+ if is_spmd() and is_torch_xla_version("<", "2.4"):
3322
+ raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
3323
+ self.partition_spec = partition_spec
3324
+
3325
+ def __call__(
3326
+ self,
3327
+ attn: Attention,
3328
+ hidden_states: torch.Tensor,
3329
+ encoder_hidden_states: Optional[torch.Tensor] = None,
3330
+ attention_mask: Optional[torch.Tensor] = None,
3331
+ temb: Optional[torch.Tensor] = None,
3332
+ *args,
3333
+ **kwargs,
3334
+ ) -> torch.Tensor:
3335
+ residual = hidden_states
3336
+ if attn.spatial_norm is not None:
3337
+ hidden_states = attn.spatial_norm(hidden_states, temb)
3338
+
3339
+ input_ndim = hidden_states.ndim
3340
+
3341
+ if input_ndim == 4:
3342
+ batch_size, channel, height, width = hidden_states.shape
3343
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
3344
+
3345
+ batch_size, sequence_length, _ = (
3346
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
3347
+ )
3348
+
3349
+ if attention_mask is not None:
3350
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
3351
+ # scaled_dot_product_attention expects attention_mask shape to be
3352
+ # (batch, heads, source_length, target_length)
3353
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
3354
+
3355
+ if attn.group_norm is not None:
3356
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
3357
+
3358
+ query = attn.to_q(hidden_states)
3359
+
3360
+ if encoder_hidden_states is None:
3361
+ encoder_hidden_states = hidden_states
3362
+ elif attn.norm_cross:
3363
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
3364
+
3365
+ key = attn.to_k(encoder_hidden_states)
3366
+ value = attn.to_v(encoder_hidden_states)
3367
+
3368
+ inner_dim = key.shape[-1]
3369
+ head_dim = inner_dim // attn.heads
3370
+
3371
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3372
+
3373
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3374
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3375
+
3376
+ if attn.norm_q is not None:
3377
+ query = attn.norm_q(query)
3378
+ if attn.norm_k is not None:
3379
+ key = attn.norm_k(key)
3380
+
3381
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
3382
+ # TODO: add support for attn.scale when we move to Torch 2.1
3383
+ if all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
3384
+ if attention_mask is not None:
3385
+ attention_mask = attention_mask.view(batch_size, 1, 1, attention_mask.shape[-1])
3386
+ # Convert mask to float and replace 0s with -inf and 1s with 0
3387
+ attention_mask = (
3388
+ attention_mask.float()
3389
+ .masked_fill(attention_mask == 0, float("-inf"))
3390
+ .masked_fill(attention_mask == 1, float(0.0))
3391
+ )
3392
+
3393
+ # Apply attention mask to key
3394
+ key = key + attention_mask
3395
+ query /= math.sqrt(query.shape[3])
3396
+ partition_spec = self.partition_spec if is_spmd() else None
3397
+ hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
3398
+ else:
3399
+ logger.warning(
3400
+ "Unable to use the flash attention pallas kernel API call due to QKV sequence length < 4096."
3401
+ )
2264
3402
  hidden_states = F.scaled_dot_product_attention(
2265
3403
  query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
2266
3404
  )
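Two details of the XLA path above are easy to miss: the pallas flash-attention kernel is only used when the query/key/value sequence length is at least 4096 (otherwise it logs a warning and falls back to SDPA), and a 0/1 attention mask is first converted into an additive float bias. A small self-contained sketch of that conversion:

    import torch

    def to_additive_bias(mask: torch.Tensor) -> torch.Tensor:
        # 1 = attend, 0 = ignore  ->  0.0 = attend, -inf = ignore
        return mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0)

    mask = torch.tensor([[1, 1, 0, 0]])
    print(to_additive_bias(mask))  # tensor([[0., 0., -inf, -inf]])
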
@@ -2284,9 +3422,9 @@ class AttnProcessorNPU:
2284
3422
  return hidden_states
2285
3423
 
2286
3424
 
2287
- class AttnProcessor2_0:
3425
+ class MochiVaeAttnProcessor2_0:
2288
3426
  r"""
2289
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
3427
+ Attention processor used in Mochi VAE.
2290
3428
  """
2291
3429
 
2292
3430
  def __init__(self):
@@ -2299,23 +3437,9 @@ class AttnProcessor2_0:
2299
3437
  hidden_states: torch.Tensor,
2300
3438
  encoder_hidden_states: Optional[torch.Tensor] = None,
2301
3439
  attention_mask: Optional[torch.Tensor] = None,
2302
- temb: Optional[torch.Tensor] = None,
2303
- *args,
2304
- **kwargs,
2305
3440
  ) -> torch.Tensor:
2306
- if len(args) > 0 or kwargs.get("scale", None) is not None:
2307
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
2308
- deprecate("scale", "1.0.0", deprecation_message)
2309
-
2310
3441
  residual = hidden_states
2311
- if attn.spatial_norm is not None:
2312
- hidden_states = attn.spatial_norm(hidden_states, temb)
2313
-
2314
- input_ndim = hidden_states.ndim
2315
-
2316
- if input_ndim == 4:
2317
- batch_size, channel, height, width = hidden_states.shape
2318
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
3442
+ is_single_frame = hidden_states.shape[1] == 1
2319
3443
 
2320
3444
  batch_size, sequence_length, _ = (
2321
3445
  hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
@@ -2327,15 +3451,24 @@ class AttnProcessor2_0:
2327
3451
  # (batch, heads, source_length, target_length)
2328
3452
  attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
2329
3453
 
2330
- if attn.group_norm is not None:
2331
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
3454
+ if is_single_frame:
3455
+ hidden_states = attn.to_v(hidden_states)
3456
+
3457
+ # linear proj
3458
+ hidden_states = attn.to_out[0](hidden_states)
3459
+ # dropout
3460
+ hidden_states = attn.to_out[1](hidden_states)
3461
+
3462
+ if attn.residual_connection:
3463
+ hidden_states = hidden_states + residual
3464
+
3465
+ hidden_states = hidden_states / attn.rescale_output_factor
3466
+ return hidden_states
2332
3467
 
2333
3468
  query = attn.to_q(hidden_states)
2334
3469
 
2335
3470
  if encoder_hidden_states is None:
2336
3471
  encoder_hidden_states = hidden_states
2337
- elif attn.norm_cross:
2338
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
2339
3472
 
2340
3473
  key = attn.to_k(encoder_hidden_states)
2341
3474
  value = attn.to_v(encoder_hidden_states)
@@ -2344,7 +3477,6 @@ class AttnProcessor2_0:
2344
3477
  head_dim = inner_dim // attn.heads
2345
3478
 
2346
3479
  query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2347
-
2348
3480
  key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2349
3481
  value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
2350
3482
 
@@ -2356,7 +3488,7 @@ class AttnProcessor2_0:
2356
3488
  # the output of sdp = (batch, num_heads, seq_len, head_dim)
2357
3489
  # TODO: add support for attn.scale when we move to Torch 2.1
2358
3490
  hidden_states = F.scaled_dot_product_attention(
2359
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
3491
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=attn.is_causal
2360
3492
  )
2361
3493
 
2362
3494
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@@ -2367,9 +3499,6 @@ class AttnProcessor2_0:
2367
3499
  # dropout
2368
3500
  hidden_states = attn.to_out[1](hidden_states)
2369
3501
 
2370
- if input_ndim == 4:
2371
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
2372
-
2373
3502
  if attn.residual_connection:
2374
3503
  hidden_states = hidden_states + residual
2375
3504
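Two behavioural changes stand out in the Mochi VAE processor above: a single-frame shortcut that skips attention entirely (only `to_v` and the output projection run), and `is_causal=attn.is_causal`, which lets the attention layer request causal masking along the sequence axis. A hedged sketch of what the causal flag amounts to:

    import torch
    import torch.nn.functional as F

    q = torch.randn(1, 4, 6, 8)  # (batch, heads, seq_len, head_dim)
    k, v = torch.randn_like(q), torch.randn_like(q)

    causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # Equivalent explicit lower-triangular mask (True = attend).
    mask = torch.ones(6, 6, dtype=torch.bool).tril()
    explicit = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

    print(torch.allclose(causal, explicit, atol=1e-6))  # True
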
 
@@ -3572,34 +4701,232 @@ class SpatialNorm(nn.Module):
3572
4701
  """
3573
4702
  Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.
3574
4703
 
3575
- Args:
3576
- f_channels (`int`):
3577
- The number of channels for input to group normalization layer, and output of the spatial norm layer.
3578
- zq_channels (`int`):
3579
- The number of channels for the quantized vector as described in the paper.
3580
- """
4704
+ Args:
4705
+ f_channels (`int`):
4706
+ The number of channels for input to group normalization layer, and output of the spatial norm layer.
4707
+ zq_channels (`int`):
4708
+ The number of channels for the quantized vector as described in the paper.
4709
+ """
4710
+
4711
+ def __init__(
4712
+ self,
4713
+ f_channels: int,
4714
+ zq_channels: int,
4715
+ ):
4716
+ super().__init__()
4717
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
4718
+ self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
4719
+ self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
4720
+
4721
+ def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
4722
+ f_size = f.shape[-2:]
4723
+ zq = F.interpolate(zq, size=f_size, mode="nearest")
4724
+ norm_f = self.norm_layer(f)
4725
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
4726
+ return new_f
4727
+
4728
+
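SpatialNorm itself is only relocated by this diff. For reference, a toy sketch of its forward pass: the conditioning `zq` is nearest-neighbour resized to the feature map's spatial size and then modulates the group-normalised features as `norm(f) * conv_y(zq) + conv_b(zq)`:

    import torch
    from diffusers.models.attention_processor import SpatialNorm

    norm = SpatialNorm(f_channels=64, zq_channels=4)   # f_channels must be divisible by 32 (GroupNorm)
    f = torch.randn(1, 64, 32, 32)                     # feature map
    zq = torch.randn(1, 4, 8, 8)                       # lower-resolution conditioning
    print(norm(f, zq).shape)                           # torch.Size([1, 64, 32, 32])
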
4729
+ class IPAdapterAttnProcessor(nn.Module):
4730
+ r"""
4731
+ Attention processor for Multiple IP-Adapters.
4732
+
4733
+ Args:
4734
+ hidden_size (`int`):
4735
+ The hidden size of the attention layer.
4736
+ cross_attention_dim (`int`):
4737
+ The number of channels in the `encoder_hidden_states`.
4738
+ num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
4739
+ The context length of the image features.
4740
+ scale (`float` or List[`float`], defaults to 1.0):
4741
+ the weight scale of image prompt.
4742
+ """
4743
+
4744
+ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
4745
+ super().__init__()
4746
+
4747
+ self.hidden_size = hidden_size
4748
+ self.cross_attention_dim = cross_attention_dim
4749
+
4750
+ if not isinstance(num_tokens, (tuple, list)):
4751
+ num_tokens = [num_tokens]
4752
+ self.num_tokens = num_tokens
4753
+
4754
+ if not isinstance(scale, list):
4755
+ scale = [scale] * len(num_tokens)
4756
+ if len(scale) != len(num_tokens):
4757
+ raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
4758
+ self.scale = scale
4759
+
4760
+ self.to_k_ip = nn.ModuleList(
4761
+ [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
4762
+ )
4763
+ self.to_v_ip = nn.ModuleList(
4764
+ [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
4765
+ )
4766
+
4767
+ def __call__(
4768
+ self,
4769
+ attn: Attention,
4770
+ hidden_states: torch.Tensor,
4771
+ encoder_hidden_states: Optional[torch.Tensor] = None,
4772
+ attention_mask: Optional[torch.Tensor] = None,
4773
+ temb: Optional[torch.Tensor] = None,
4774
+ scale: float = 1.0,
4775
+ ip_adapter_masks: Optional[torch.Tensor] = None,
4776
+ ):
4777
+ residual = hidden_states
4778
+
4779
+ # separate ip_hidden_states from encoder_hidden_states
4780
+ if encoder_hidden_states is not None:
4781
+ if isinstance(encoder_hidden_states, tuple):
4782
+ encoder_hidden_states, ip_hidden_states = encoder_hidden_states
4783
+ else:
4784
+ deprecation_message = (
4785
+ "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
4786
+ " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
4787
+ )
4788
+ deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
4789
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
4790
+ encoder_hidden_states, ip_hidden_states = (
4791
+ encoder_hidden_states[:, :end_pos, :],
4792
+ [encoder_hidden_states[:, end_pos:, :]],
4793
+ )
4794
+
4795
+ if attn.spatial_norm is not None:
4796
+ hidden_states = attn.spatial_norm(hidden_states, temb)
4797
+
4798
+ input_ndim = hidden_states.ndim
4799
+
4800
+ if input_ndim == 4:
4801
+ batch_size, channel, height, width = hidden_states.shape
4802
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
4803
+
4804
+ batch_size, sequence_length, _ = (
4805
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
4806
+ )
4807
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
4808
+
4809
+ if attn.group_norm is not None:
4810
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
4811
+
4812
+ query = attn.to_q(hidden_states)
4813
+
4814
+ if encoder_hidden_states is None:
4815
+ encoder_hidden_states = hidden_states
4816
+ elif attn.norm_cross:
4817
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
4818
+
4819
+ key = attn.to_k(encoder_hidden_states)
4820
+ value = attn.to_v(encoder_hidden_states)
4821
+
4822
+ query = attn.head_to_batch_dim(query)
4823
+ key = attn.head_to_batch_dim(key)
4824
+ value = attn.head_to_batch_dim(value)
4825
+
4826
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
4827
+ hidden_states = torch.bmm(attention_probs, value)
4828
+ hidden_states = attn.batch_to_head_dim(hidden_states)
4829
+
4830
+ if ip_adapter_masks is not None:
4831
+ if not isinstance(ip_adapter_masks, List):
4832
+ # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
4833
+ ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
4834
+ if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
4835
+ raise ValueError(
4836
+ f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
4837
+ f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
4838
+ f"({len(ip_hidden_states)})"
4839
+ )
4840
+ else:
4841
+ for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
4842
+ if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
4843
+ raise ValueError(
4844
+ "Each element of the ip_adapter_masks array should be a tensor with shape "
4845
+ "[1, num_images_for_ip_adapter, height, width]."
4846
+ " Please use `IPAdapterMaskProcessor` to preprocess your mask"
4847
+ )
4848
+ if mask.shape[1] != ip_state.shape[1]:
4849
+ raise ValueError(
4850
+ f"Number of masks ({mask.shape[1]}) does not match "
4851
+ f"number of ip images ({ip_state.shape[1]}) at index {index}"
4852
+ )
4853
+ if isinstance(scale, list) and not len(scale) == mask.shape[1]:
4854
+ raise ValueError(
4855
+ f"Number of masks ({mask.shape[1]}) does not match "
4856
+ f"number of scales ({len(scale)}) at index {index}"
4857
+ )
4858
+ else:
4859
+ ip_adapter_masks = [None] * len(self.scale)
4860
+
4861
+ # for ip-adapter
4862
+ for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
4863
+ ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
4864
+ ):
4865
+ skip = False
4866
+ if isinstance(scale, list):
4867
+ if all(s == 0 for s in scale):
4868
+ skip = True
4869
+ elif scale == 0:
4870
+ skip = True
4871
+ if not skip:
4872
+ if mask is not None:
4873
+ if not isinstance(scale, list):
4874
+ scale = [scale] * mask.shape[1]
4875
+
4876
+ current_num_images = mask.shape[1]
4877
+ for i in range(current_num_images):
4878
+ ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
4879
+ ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
4880
+
4881
+ ip_key = attn.head_to_batch_dim(ip_key)
4882
+ ip_value = attn.head_to_batch_dim(ip_value)
4883
+
4884
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
4885
+ _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
4886
+ _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
4887
+
4888
+ mask_downsample = IPAdapterMaskProcessor.downsample(
4889
+ mask[:, i, :, :],
4890
+ batch_size,
4891
+ _current_ip_hidden_states.shape[1],
4892
+ _current_ip_hidden_states.shape[2],
4893
+ )
4894
+
4895
+ mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
4896
+
4897
+ hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
4898
+ else:
4899
+ ip_key = to_k_ip(current_ip_hidden_states)
4900
+ ip_value = to_v_ip(current_ip_hidden_states)
4901
+
4902
+ ip_key = attn.head_to_batch_dim(ip_key)
4903
+ ip_value = attn.head_to_batch_dim(ip_value)
4904
+
4905
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
4906
+ current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
4907
+ current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
4908
+
4909
+ hidden_states = hidden_states + scale * current_ip_hidden_states
4910
+
4911
+ # linear proj
4912
+ hidden_states = attn.to_out[0](hidden_states)
4913
+ # dropout
4914
+ hidden_states = attn.to_out[1](hidden_states)
4915
+
4916
+ if input_ndim == 4:
4917
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
3581
4918
 
3582
- def __init__(
3583
- self,
3584
- f_channels: int,
3585
- zq_channels: int,
3586
- ):
3587
- super().__init__()
3588
- self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
3589
- self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
3590
- self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
4919
+ if attn.residual_connection:
4920
+ hidden_states = hidden_states + residual
3591
4921
 
3592
- def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
3593
- f_size = f.shape[-2:]
3594
- zq = F.interpolate(zq, size=f_size, mode="nearest")
3595
- norm_f = self.norm_layer(f)
3596
- new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
3597
- return new_f
4922
+ hidden_states = hidden_states / attn.rescale_output_factor
4923
+
4924
+ return hidden_states
3598
4925
 
3599
4926
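All of the IP-Adapter processors in this region accept an optional `ip_adapter_masks` argument, validated against `self.scale` and the incoming `ip_hidden_states`. A hedged end-user sketch of how masks usually reach them, via `IPAdapterMaskProcessor` and `cross_attention_kwargs` (the file names and sizes are placeholders):

    from PIL import Image
    from diffusers.image_processor import IPAdapterMaskProcessor

    masks = [Image.open("face_mask.png"), Image.open("background_mask.png")]
    ip_masks = IPAdapterMaskProcessor().preprocess(masks, height=1024, width=1024)

    # Passed through the pipeline call; each attention processor then receives `ip_adapter_masks`.
    # image = pipe(prompt, ip_adapter_image=[face_image, background_image],
    #              cross_attention_kwargs={"ip_adapter_masks": ip_masks}).images[0]
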
 
3600
- class IPAdapterAttnProcessor(nn.Module):
4927
+ class IPAdapterAttnProcessor2_0(torch.nn.Module):
3601
4928
  r"""
3602
- Attention processor for Multiple IP-Adapters.
4929
+ Attention processor for IP-Adapter for PyTorch 2.0.
3603
4930
 
3604
4931
  Args:
3605
4932
  hidden_size (`int`):
@@ -3608,13 +4935,18 @@ class IPAdapterAttnProcessor(nn.Module):
3608
4935
  The number of channels in the `encoder_hidden_states`.
3609
4936
  num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
3610
4937
  The context length of the image features.
3611
- scale (`float` or List[`float`], defaults to 1.0):
4938
+ scale (`float` or `List[float]`, defaults to 1.0):
3612
4939
  the weight scale of image prompt.
3613
4940
  """
3614
4941
 
3615
4942
  def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
3616
4943
  super().__init__()
3617
4944
 
4945
+ if not hasattr(F, "scaled_dot_product_attention"):
4946
+ raise ImportError(
4947
+ f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
4948
+ )
4949
+
3618
4950
  self.hidden_size = hidden_size
3619
4951
  self.cross_attention_dim = cross_attention_dim
3620
4952
 
@@ -3675,7 +5007,12 @@ class IPAdapterAttnProcessor(nn.Module):
3675
5007
  batch_size, sequence_length, _ = (
3676
5008
  hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
3677
5009
  )
3678
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
5010
+
5011
+ if attention_mask is not None:
5012
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
5013
+ # scaled_dot_product_attention expects attention_mask shape to be
5014
+ # (batch, heads, source_length, target_length)
5015
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
3679
5016
 
3680
5017
  if attn.group_norm is not None:
3681
5018
  hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@@ -3690,13 +5027,22 @@ class IPAdapterAttnProcessor(nn.Module):
3690
5027
  key = attn.to_k(encoder_hidden_states)
3691
5028
  value = attn.to_v(encoder_hidden_states)
3692
5029
 
3693
- query = attn.head_to_batch_dim(query)
3694
- key = attn.head_to_batch_dim(key)
3695
- value = attn.head_to_batch_dim(value)
5030
+ inner_dim = key.shape[-1]
5031
+ head_dim = inner_dim // attn.heads
3696
5032
 
3697
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
3698
- hidden_states = torch.bmm(attention_probs, value)
3699
- hidden_states = attn.batch_to_head_dim(hidden_states)
5033
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
5034
+
5035
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
5036
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
5037
+
5038
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
5039
+ # TODO: add support for attn.scale when we move to Torch 2.1
5040
+ hidden_states = F.scaled_dot_product_attention(
5041
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
5042
+ )
5043
+
5044
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
5045
+ hidden_states = hidden_states.to(query.dtype)
3700
5046
 
3701
5047
  if ip_adapter_masks is not None:
3702
5048
  if not isinstance(ip_adapter_masks, List):
@@ -3749,12 +5095,19 @@ class IPAdapterAttnProcessor(nn.Module):
3749
5095
  ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
3750
5096
  ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
3751
5097
 
3752
- ip_key = attn.head_to_batch_dim(ip_key)
3753
- ip_value = attn.head_to_batch_dim(ip_value)
5098
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
5099
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3754
5100
 
3755
- ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
3756
- _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
3757
- _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
5101
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
5102
+ # TODO: add support for attn.scale when we move to Torch 2.1
5103
+ _current_ip_hidden_states = F.scaled_dot_product_attention(
5104
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
5105
+ )
5106
+
5107
+ _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
5108
+ batch_size, -1, attn.heads * head_dim
5109
+ )
5110
+ _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
3758
5111
 
3759
5112
  mask_downsample = IPAdapterMaskProcessor.downsample(
3760
5113
  mask[:, i, :, :],
@@ -3764,18 +5117,24 @@ class IPAdapterAttnProcessor(nn.Module):
3764
5117
  )
3765
5118
 
3766
5119
  mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
3767
-
3768
5120
  hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
3769
5121
  else:
3770
5122
  ip_key = to_k_ip(current_ip_hidden_states)
3771
5123
  ip_value = to_v_ip(current_ip_hidden_states)
3772
5124
 
3773
- ip_key = attn.head_to_batch_dim(ip_key)
3774
- ip_value = attn.head_to_batch_dim(ip_value)
5125
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
5126
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3775
5127
 
3776
- ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
3777
- current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
3778
- current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
5128
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
5129
+ # TODO: add support for attn.scale when we move to Torch 2.1
5130
+ current_ip_hidden_states = F.scaled_dot_product_attention(
5131
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
5132
+ )
5133
+
5134
+ current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
5135
+ batch_size, -1, attn.heads * head_dim
5136
+ )
5137
+ current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
3779
5138
 
3780
5139
  hidden_states = hidden_states + scale * current_ip_hidden_states
3781
5140
 
@@ -3795,9 +5154,9 @@ class IPAdapterAttnProcessor(nn.Module):
3795
5154
  return hidden_states
3796
5155
 
3797
5156
 
3798
- class IPAdapterAttnProcessor2_0(torch.nn.Module):
5157
+ class IPAdapterXFormersAttnProcessor(torch.nn.Module):
3799
5158
  r"""
3800
- Attention processor for IP-Adapter for PyTorch 2.0.
5159
+ Attention processor for IP-Adapter using xFormers.
3801
5160
 
3802
5161
  Args:
3803
5162
  hidden_size (`int`):
@@ -3808,18 +5167,26 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
3808
5167
  The context length of the image features.
3809
5168
  scale (`float` or `List[float]`, defaults to 1.0):
3810
5169
  the weight scale of image prompt.
5170
+ attention_op (`Callable`, *optional*, defaults to `None`):
5171
+ The base
5172
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
5173
+ use as the attention operator. It is recommended to leave this as `None` and let xFormers choose the best
5174
+ operator.
3811
5175
  """
3812
5176
 
3813
- def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
5177
+ def __init__(
5178
+ self,
5179
+ hidden_size,
5180
+ cross_attention_dim=None,
5181
+ num_tokens=(4,),
5182
+ scale=1.0,
5183
+ attention_op: Optional[Callable] = None,
5184
+ ):
3814
5185
  super().__init__()
3815
5186
 
3816
- if not hasattr(F, "scaled_dot_product_attention"):
3817
- raise ImportError(
3818
- f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
3819
- )
3820
-
3821
5187
  self.hidden_size = hidden_size
3822
5188
  self.cross_attention_dim = cross_attention_dim
5189
+ self.attention_op = attention_op
3823
5190
 
3824
5191
  if not isinstance(num_tokens, (tuple, list)):
3825
5192
  num_tokens = [num_tokens]
@@ -3832,21 +5199,21 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
3832
5199
  self.scale = scale
3833
5200
 
3834
5201
  self.to_k_ip = nn.ModuleList(
3835
- [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
5202
+ [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
3836
5203
  )
3837
5204
  self.to_v_ip = nn.ModuleList(
3838
- [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
5205
+ [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
3839
5206
  )
3840
5207
 
3841
5208
  def __call__(
3842
5209
  self,
3843
5210
  attn: Attention,
3844
- hidden_states: torch.Tensor,
3845
- encoder_hidden_states: Optional[torch.Tensor] = None,
3846
- attention_mask: Optional[torch.Tensor] = None,
3847
- temb: Optional[torch.Tensor] = None,
5211
+ hidden_states: torch.FloatTensor,
5212
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
5213
+ attention_mask: Optional[torch.FloatTensor] = None,
5214
+ temb: Optional[torch.FloatTensor] = None,
3848
5215
  scale: float = 1.0,
3849
- ip_adapter_masks: Optional[torch.Tensor] = None,
5216
+ ip_adapter_masks: Optional[torch.FloatTensor] = None,
3850
5217
  ):
3851
5218
  residual = hidden_states
3852
5219
 
@@ -3881,9 +5248,14 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
3881
5248
 
3882
5249
  if attention_mask is not None:
3883
5250
  attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
3884
- # scaled_dot_product_attention expects attention_mask shape to be
3885
- # (batch, heads, source_length, target_length)
3886
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
5251
+ # expand our mask's singleton query_tokens dimension:
5252
+ # [batch*heads, 1, key_tokens] ->
5253
+ # [batch*heads, query_tokens, key_tokens]
5254
+ # so that it can be added as a bias onto the attention scores that xformers computes:
5255
+ # [batch*heads, query_tokens, key_tokens]
5256
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
5257
+ _, query_tokens, _ = hidden_states.shape
5258
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
3887
5259
 
3888
5260
  if attn.group_norm is not None:
3889
5261
  hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@@ -3898,131 +5270,291 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
3898
5270
  key = attn.to_k(encoder_hidden_states)
3899
5271
  value = attn.to_v(encoder_hidden_states)
3900
5272
 
5273
+ query = attn.head_to_batch_dim(query).contiguous()
5274
+ key = attn.head_to_batch_dim(key).contiguous()
5275
+ value = attn.head_to_batch_dim(value).contiguous()
5276
+
5277
+ hidden_states = xformers.ops.memory_efficient_attention(
5278
+ query, key, value, attn_bias=attention_mask, op=self.attention_op
5279
+ )
5280
+ hidden_states = hidden_states.to(query.dtype)
5281
+ hidden_states = attn.batch_to_head_dim(hidden_states)
5282
+
5283
+ if ip_hidden_states:
5284
+ if ip_adapter_masks is not None:
5285
+ if not isinstance(ip_adapter_masks, List):
5286
+ # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
5287
+ ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
5288
+ if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
5289
+ raise ValueError(
5290
+ f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
5291
+ f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
5292
+ f"({len(ip_hidden_states)})"
5293
+ )
5294
+ else:
5295
+ for index, (mask, scale, ip_state) in enumerate(
5296
+ zip(ip_adapter_masks, self.scale, ip_hidden_states)
5297
+ ):
5298
+ if mask is None:
5299
+ continue
5300
+ if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
5301
+ raise ValueError(
5302
+ "Each element of the ip_adapter_masks array should be a tensor with shape "
5303
+ "[1, num_images_for_ip_adapter, height, width]."
5304
+ " Please use `IPAdapterMaskProcessor` to preprocess your mask"
5305
+ )
5306
+ if mask.shape[1] != ip_state.shape[1]:
5307
+ raise ValueError(
5308
+ f"Number of masks ({mask.shape[1]}) does not match "
5309
+ f"number of ip images ({ip_state.shape[1]}) at index {index}"
5310
+ )
5311
+ if isinstance(scale, list) and not len(scale) == mask.shape[1]:
5312
+ raise ValueError(
5313
+ f"Number of masks ({mask.shape[1]}) does not match "
5314
+ f"number of scales ({len(scale)}) at index {index}"
5315
+ )
5316
+ else:
5317
+ ip_adapter_masks = [None] * len(self.scale)
5318
+
5319
+ # for ip-adapter
5320
+ for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
5321
+ ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
5322
+ ):
5323
+ skip = False
5324
+ if isinstance(scale, list):
5325
+ if all(s == 0 for s in scale):
5326
+ skip = True
5327
+ elif scale == 0:
5328
+ skip = True
5329
+ if not skip:
5330
+ if mask is not None:
5331
+ mask = mask.to(torch.float16)
5332
+ if not isinstance(scale, list):
5333
+ scale = [scale] * mask.shape[1]
5334
+
5335
+ current_num_images = mask.shape[1]
5336
+ for i in range(current_num_images):
5337
+ ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
5338
+ ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
5339
+
5340
+ ip_key = attn.head_to_batch_dim(ip_key).contiguous()
5341
+ ip_value = attn.head_to_batch_dim(ip_value).contiguous()
5342
+
5343
+ _current_ip_hidden_states = xformers.ops.memory_efficient_attention(
5344
+ query, ip_key, ip_value, op=self.attention_op
5345
+ )
5346
+ _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
5347
+ _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
5348
+
5349
+ mask_downsample = IPAdapterMaskProcessor.downsample(
5350
+ mask[:, i, :, :],
5351
+ batch_size,
5352
+ _current_ip_hidden_states.shape[1],
5353
+ _current_ip_hidden_states.shape[2],
5354
+ )
5355
+
5356
+ mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
5357
+ hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
5358
+ else:
5359
+ ip_key = to_k_ip(current_ip_hidden_states)
5360
+ ip_value = to_v_ip(current_ip_hidden_states)
5361
+
5362
+ ip_key = attn.head_to_batch_dim(ip_key).contiguous()
5363
+ ip_value = attn.head_to_batch_dim(ip_value).contiguous()
5364
+
5365
+ current_ip_hidden_states = xformers.ops.memory_efficient_attention(
5366
+ query, ip_key, ip_value, op=self.attention_op
5367
+ )
5368
+ current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
5369
+ current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
5370
+
5371
+ hidden_states = hidden_states + scale * current_ip_hidden_states
5372
+
5373
+ # linear proj
5374
+ hidden_states = attn.to_out[0](hidden_states)
5375
+ # dropout
5376
+ hidden_states = attn.to_out[1](hidden_states)
5377
+
5378
+ if input_ndim == 4:
5379
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
5380
+
5381
+ if attn.residual_connection:
5382
+ hidden_states = hidden_states + residual
5383
+
5384
+ hidden_states = hidden_states / attn.rescale_output_factor
5385
+
5386
+ return hidden_states
5387
+
5388
+
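`IPAdapterXFormersAttnProcessor` is normally installed by the library rather than constructed by hand; a minimal usage sketch, assuming xformers is installed and that the referenced Hub checkpoint and IP-Adapter weights are available:

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")
    pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
    # Enabling xFormers attention is expected to swap the IP-Adapter processors
    # for their xFormers-based variants such as the class defined above.
    pipe.enable_xformers_memory_efficient_attention()
    pipe.set_ip_adapter_scale(0.6)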
5389
+ class SD3IPAdapterJointAttnProcessor2_0(torch.nn.Module):
5390
+ """
5391
+ Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections, with
5392
+ additional image-based information and timestep embeddings.
5393
+
5394
+ Args:
5395
+ hidden_size (`int`):
5396
+ The number of hidden channels.
5397
+ ip_hidden_states_dim (`int`):
5398
+ The image feature dimension.
5399
+ head_dim (`int`):
5400
+ The number of head channels.
5401
+ timesteps_emb_dim (`int`, defaults to 1280):
5402
+ The number of input channels for timestep embedding.
5403
+ scale (`float`, defaults to 0.5):
5404
+ IP-Adapter scale.
5405
+ """
5406
+
5407
+ def __init__(
5408
+ self,
5409
+ hidden_size: int,
5410
+ ip_hidden_states_dim: int,
5411
+ head_dim: int,
5412
+ timesteps_emb_dim: int = 1280,
5413
+ scale: float = 0.5,
5414
+ ):
5415
+ super().__init__()
5416
+
5417
+ # To prevent circular import
5418
+ from .normalization import AdaLayerNorm, RMSNorm
5419
+
5420
+ self.norm_ip = AdaLayerNorm(timesteps_emb_dim, output_dim=ip_hidden_states_dim * 2, norm_eps=1e-6, chunk_dim=1)
5421
+ self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
5422
+ self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
5423
+ self.norm_q = RMSNorm(head_dim, 1e-6)
5424
+ self.norm_k = RMSNorm(head_dim, 1e-6)
5425
+ self.norm_ip_k = RMSNorm(head_dim, 1e-6)
5426
+ self.scale = scale
5427
+
5428
+ def __call__(
5429
+ self,
5430
+ attn: Attention,
5431
+ hidden_states: torch.FloatTensor,
5432
+ encoder_hidden_states: torch.FloatTensor = None,
5433
+ attention_mask: Optional[torch.FloatTensor] = None,
5434
+ ip_hidden_states: torch.FloatTensor = None,
5435
+ temb: torch.FloatTensor = None,
5436
+ ) -> torch.FloatTensor:
5437
+ """
5438
+ Perform the attention computation, integrating image features (if provided) and timestep embeddings.
5439
+
5440
+ If `ip_hidden_states` is `None`, this is equivalent to using JointAttnProcessor2_0.
5441
+
5442
+ Args:
5443
+ attn (`Attention`):
5444
+ Attention instance.
5445
+ hidden_states (`torch.FloatTensor`):
5446
+ Input `hidden_states`.
5447
+ encoder_hidden_states (`torch.FloatTensor`, *optional*):
5448
+ The encoder hidden states.
5449
+ attention_mask (`torch.FloatTensor`, *optional*):
5450
+ Attention mask.
5451
+ ip_hidden_states (`torch.FloatTensor`, *optional*):
5452
+ Image embeddings.
5453
+ temb (`torch.FloatTensor`, *optional*):
5454
+ Timestep embeddings.
5455
+
5456
+ Returns:
5457
+ `torch.FloatTensor`: Output hidden states.
5458
+ """
5459
+ residual = hidden_states
5460
+
5461
+ batch_size = hidden_states.shape[0]
5462
+
5463
+ # `sample` projections.
5464
+ query = attn.to_q(hidden_states)
5465
+ key = attn.to_k(hidden_states)
5466
+ value = attn.to_v(hidden_states)
5467
+
3901
5468
  inner_dim = key.shape[-1]
3902
5469
  head_dim = inner_dim // attn.heads
3903
5470
 
3904
5471
  query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3905
-
3906
5472
  key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3907
5473
  value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
5474
+ img_query = query
5475
+ img_key = key
5476
+ img_value = value
3908
5477
 
3909
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
3910
- # TODO: add support for attn.scale when we move to Torch 2.1
3911
- hidden_states = F.scaled_dot_product_attention(
3912
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
3913
- )
3914
-
3915
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
3916
- hidden_states = hidden_states.to(query.dtype)
5478
+ if attn.norm_q is not None:
5479
+ query = attn.norm_q(query)
5480
+ if attn.norm_k is not None:
5481
+ key = attn.norm_k(key)
3917
5482
 
3918
- if ip_adapter_masks is not None:
3919
- if not isinstance(ip_adapter_masks, List):
3920
- # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
3921
- ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
3922
- if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
3923
- raise ValueError(
3924
- f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
3925
- f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
3926
- f"({len(ip_hidden_states)})"
3927
- )
3928
- else:
3929
- for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
3930
- if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
3931
- raise ValueError(
3932
- "Each element of the ip_adapter_masks array should be a tensor with shape "
3933
- "[1, num_images_for_ip_adapter, height, width]."
3934
- " Please use `IPAdapterMaskProcessor` to preprocess your mask"
3935
- )
3936
- if mask.shape[1] != ip_state.shape[1]:
3937
- raise ValueError(
3938
- f"Number of masks ({mask.shape[1]}) does not match "
3939
- f"number of ip images ({ip_state.shape[1]}) at index {index}"
3940
- )
3941
- if isinstance(scale, list) and not len(scale) == mask.shape[1]:
3942
- raise ValueError(
3943
- f"Number of masks ({mask.shape[1]}) does not match "
3944
- f"number of scales ({len(scale)}) at index {index}"
3945
- )
3946
- else:
3947
- ip_adapter_masks = [None] * len(self.scale)
5483
+ # `context` projections.
5484
+ if encoder_hidden_states is not None:
5485
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
5486
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
5487
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
3948
5488
 
3949
- # for ip-adapter
3950
- for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
3951
- ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
3952
- ):
3953
- skip = False
3954
- if isinstance(scale, list):
3955
- if all(s == 0 for s in scale):
3956
- skip = True
3957
- elif scale == 0:
3958
- skip = True
3959
- if not skip:
3960
- if mask is not None:
3961
- if not isinstance(scale, list):
3962
- scale = [scale] * mask.shape[1]
5489
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
5490
+ batch_size, -1, attn.heads, head_dim
5491
+ ).transpose(1, 2)
5492
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
5493
+ batch_size, -1, attn.heads, head_dim
5494
+ ).transpose(1, 2)
5495
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
5496
+ batch_size, -1, attn.heads, head_dim
5497
+ ).transpose(1, 2)
3963
5498
 
3964
- current_num_images = mask.shape[1]
3965
- for i in range(current_num_images):
3966
- ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
3967
- ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
5499
+ if attn.norm_added_q is not None:
5500
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
5501
+ if attn.norm_added_k is not None:
5502
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
3968
5503
 
3969
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3970
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
5504
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
5505
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
5506
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
3971
5507
 
3972
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
3973
- # TODO: add support for attn.scale when we move to Torch 2.1
3974
- _current_ip_hidden_states = F.scaled_dot_product_attention(
3975
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
3976
- )
5508
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
5509
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
5510
+ hidden_states = hidden_states.to(query.dtype)
3977
5511
 
3978
- _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
3979
- batch_size, -1, attn.heads * head_dim
3980
- )
3981
- _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
5512
+ if encoder_hidden_states is not None:
5513
+ # Split the attention outputs.
5514
+ hidden_states, encoder_hidden_states = (
5515
+ hidden_states[:, : residual.shape[1]],
5516
+ hidden_states[:, residual.shape[1] :],
5517
+ )
5518
+ if not attn.context_pre_only:
5519
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
3982
5520
 
3983
- mask_downsample = IPAdapterMaskProcessor.downsample(
3984
- mask[:, i, :, :],
3985
- batch_size,
3986
- _current_ip_hidden_states.shape[1],
3987
- _current_ip_hidden_states.shape[2],
3988
- )
5521
+ # IP Adapter
5522
+ if self.scale != 0 and ip_hidden_states is not None:
5523
+ # Norm image features
5524
+ norm_ip_hidden_states = self.norm_ip(ip_hidden_states, temb=temb)
3989
5525
 
3990
- mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
3991
- hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
3992
- else:
3993
- ip_key = to_k_ip(current_ip_hidden_states)
3994
- ip_value = to_v_ip(current_ip_hidden_states)
5526
+ # To k and v
5527
+ ip_key = self.to_k_ip(norm_ip_hidden_states)
5528
+ ip_value = self.to_v_ip(norm_ip_hidden_states)
3995
5529
 
3996
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3997
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
5530
+ # Reshape
5531
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
5532
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
3998
5533
 
3999
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
4000
- # TODO: add support for attn.scale when we move to Torch 2.1
4001
- current_ip_hidden_states = F.scaled_dot_product_attention(
4002
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
4003
- )
5534
+ # Norm
5535
+ query = self.norm_q(img_query)
5536
+ img_key = self.norm_k(img_key)
5537
+ ip_key = self.norm_ip_k(ip_key)
4004
5538
 
4005
- current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
4006
- batch_size, -1, attn.heads * head_dim
4007
- )
4008
- current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
5539
+ # cat img
5540
+ key = torch.cat([img_key, ip_key], dim=2)
5541
+ value = torch.cat([img_value, ip_value], dim=2)
4009
5542
 
4010
- hidden_states = hidden_states + scale * current_ip_hidden_states
5543
+ ip_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
5544
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).view(batch_size, -1, attn.heads * head_dim)
5545
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
5546
+
5547
+ hidden_states = hidden_states + ip_hidden_states * self.scale
4011
5548
 
4012
5549
  # linear proj
4013
5550
  hidden_states = attn.to_out[0](hidden_states)
4014
5551
  # dropout
4015
5552
  hidden_states = attn.to_out[1](hidden_states)
4016
5553
 
4017
- if input_ndim == 4:
4018
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
4019
-
4020
- if attn.residual_connection:
4021
- hidden_states = hidden_states + residual
4022
-
4023
- hidden_states = hidden_states / attn.rescale_output_factor
4024
-
4025
- return hidden_states
5554
+ if encoder_hidden_states is not None:
5555
+ return hidden_states, encoder_hidden_states
5556
+ else:
5557
+ return hidden_states
4026
5558
 
4027
5559
 
4028
5560
  class PAGIdentitySelfAttnProcessor2_0:
@@ -4227,26 +5759,272 @@ class PAGCFGIdentitySelfAttnProcessor2_0:
4227
5759
  return hidden_states
4228
5760
 
4229
5761
 
5762
+ class SanaMultiscaleAttnProcessor2_0:
5763
+ r"""
5764
+ Processor for implementing multiscale quadratic attention.
5765
+ """
5766
+
5767
+ def __call__(self, attn: SanaMultiscaleLinearAttention, hidden_states: torch.Tensor) -> torch.Tensor:
5768
+ height, width = hidden_states.shape[-2:]
5769
+ if height * width > attn.attention_head_dim:
5770
+ use_linear_attention = True
5771
+ else:
5772
+ use_linear_attention = False
5773
+
5774
+ residual = hidden_states
5775
+
5776
+ batch_size, _, height, width = list(hidden_states.size())
5777
+ original_dtype = hidden_states.dtype
5778
+
5779
+ hidden_states = hidden_states.movedim(1, -1)
5780
+ query = attn.to_q(hidden_states)
5781
+ key = attn.to_k(hidden_states)
5782
+ value = attn.to_v(hidden_states)
5783
+ hidden_states = torch.cat([query, key, value], dim=3)
5784
+ hidden_states = hidden_states.movedim(-1, 1)
5785
+
5786
+ multi_scale_qkv = [hidden_states]
5787
+ for block in attn.to_qkv_multiscale:
5788
+ multi_scale_qkv.append(block(hidden_states))
5789
+
5790
+ hidden_states = torch.cat(multi_scale_qkv, dim=1)
5791
+
5792
+ if use_linear_attention:
5793
+ # for linear attention upcast hidden_states to float32
5794
+ hidden_states = hidden_states.to(dtype=torch.float32)
5795
+
5796
+ hidden_states = hidden_states.reshape(batch_size, -1, 3 * attn.attention_head_dim, height * width)
5797
+
5798
+ query, key, value = hidden_states.chunk(3, dim=2)
5799
+ query = attn.nonlinearity(query)
5800
+ key = attn.nonlinearity(key)
5801
+
5802
+ if use_linear_attention:
5803
+ hidden_states = attn.apply_linear_attention(query, key, value)
5804
+ hidden_states = hidden_states.to(dtype=original_dtype)
5805
+ else:
5806
+ hidden_states = attn.apply_quadratic_attention(query, key, value)
5807
+
5808
+ hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
5809
+ hidden_states = attn.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
5810
+
5811
+ if attn.norm_type == "rms_norm":
5812
+ hidden_states = attn.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
5813
+ else:
5814
+ hidden_states = attn.norm_out(hidden_states)
5815
+
5816
+ if attn.residual_connection:
5817
+ hidden_states = hidden_states + residual
5818
+
5819
+ return hidden_states
5820
+
5821
+
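The `use_linear_attention` switch above compares the number of spatial tokens against the head dimension; a small standalone illustration of that rule (numbers are made up):

    attention_head_dim = 32  # hypothetical head dimension
    for height, width in [(4, 4), (16, 16)]:
        use_linear_attention = height * width > attention_head_dim
        print(f"{height}x{width}: {'linear' if use_linear_attention else 'quadratic'} attention")
    # 4x4   -> 16 tokens, quadratic attention (exact attention is still cheap)
    # 16x16 -> 256 tokens, linear attention (quadratic cost would dominate)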
4230
5822
  class LoRAAttnProcessor:
5823
+ r"""
5824
+ Processor for implementing attention with LoRA.
5825
+ """
5826
+
4231
5827
  def __init__(self):
4232
5828
  pass
4233
5829
 
4234
5830
 
4235
5831
  class LoRAAttnProcessor2_0:
5832
+ r"""
5833
+ Processor for implementing attention with LoRA (enabled by default if you're using PyTorch 2.0).
5834
+ """
5835
+
4236
5836
  def __init__(self):
4237
5837
  pass
4238
5838
 
4239
5839
 
4240
5840
  class LoRAXFormersAttnProcessor:
5841
+ r"""
5842
+ Processor for implementing attention with LoRA using xFormers.
5843
+ """
5844
+
4241
5845
  def __init__(self):
4242
5846
  pass
4243
5847
 
4244
5848
 
4245
5849
  class LoRAAttnAddedKVProcessor:
5850
+ r"""
5851
+ Processor for implementing attention with LoRA with extra learnable key and value matrices for the text encoder.
5852
+ """
5853
+
4246
5854
  def __init__(self):
4247
5855
  pass
4248
5856
 
4249
5857
 
5858
+ class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
5859
+ r"""
5860
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
5861
+ """
5862
+
5863
+ def __init__(self):
5864
+ deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
5865
+ deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
5866
+ super().__init__()
5867
+
5868
+
5869
+ class SanaLinearAttnProcessor2_0:
5870
+ r"""
5871
+ Processor for implementing scaled dot-product linear attention.
5872
+ """
5873
+
5874
+ def __call__(
5875
+ self,
5876
+ attn: Attention,
5877
+ hidden_states: torch.Tensor,
5878
+ encoder_hidden_states: Optional[torch.Tensor] = None,
5879
+ attention_mask: Optional[torch.Tensor] = None,
5880
+ ) -> torch.Tensor:
5881
+ original_dtype = hidden_states.dtype
5882
+
5883
+ if encoder_hidden_states is None:
5884
+ encoder_hidden_states = hidden_states
5885
+
5886
+ query = attn.to_q(hidden_states)
5887
+ key = attn.to_k(encoder_hidden_states)
5888
+ value = attn.to_v(encoder_hidden_states)
5889
+
5890
+ query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
5891
+ key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
5892
+ value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
5893
+
5894
+ query = F.relu(query)
5895
+ key = F.relu(key)
5896
+
5897
+ query, key, value = query.float(), key.float(), value.float()
5898
+
5899
+ value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
5900
+ scores = torch.matmul(value, key)
5901
+ hidden_states = torch.matmul(scores, query)
5902
+
5903
+ hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + 1e-15)
5904
+ hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
5905
+ hidden_states = hidden_states.to(original_dtype)
5906
+
5907
+ hidden_states = attn.to_out[0](hidden_states)
5908
+ hidden_states = attn.to_out[1](hidden_states)
5909
+
5910
+ if original_dtype == torch.float16:
5911
+ hidden_states = hidden_states.clip(-65504, 65504)
5912
+
5913
+ return hidden_states
5914
+
5915
+
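The `F.pad(value, (0, 0, 0, 1), ... value=1.0)` step above is the usual trick for computing linear attention and its normalizer in a single matmul chain; a standalone sketch with hypothetical shapes (the 1e-15 epsilon mirrors the processor):

    import torch

    B, H, d, N = 1, 2, 8, 16                 # batch, heads, head_dim, tokens (assumed)
    q = torch.relu(torch.randn(B, H, d, N))  # phi(Q), laid out as (..., head_dim, tokens)
    k = torch.relu(torch.randn(B, H, N, d))  # phi(K), laid out as (..., tokens, head_dim)
    v = torch.randn(B, H, d, N)

    # Appending a row of ones to V makes one matmul chain yield both
    #   numerator_i  = sum_j v_j * <phi(k_j), phi(q_i)>   (rows 0 .. d-1)
    #   normalizer_i = sum_j       <phi(k_j), phi(q_i)>   (last row)
    v1 = torch.nn.functional.pad(v, (0, 0, 0, 1), value=1.0)  # (..., d + 1, N)
    out = (v1 @ k) @ q                                         # (..., d + 1, N)
    out = out[:, :, :-1] / (out[:, :, -1:] + 1e-15)

    # Same result computed in two passes:
    ref = (v @ k @ q) / (torch.ones(B, H, 1, N) @ k @ q + 1e-15)
    assert torch.allclose(out, ref, atol=1e-5)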
5916
+ class PAGCFGSanaLinearAttnProcessor2_0:
5917
+ r"""
5918
+ Processor for implementing scaled dot-product linear attention.
5919
+ """
5920
+
5921
+ def __call__(
5922
+ self,
5923
+ attn: Attention,
5924
+ hidden_states: torch.Tensor,
5925
+ encoder_hidden_states: Optional[torch.Tensor] = None,
5926
+ attention_mask: Optional[torch.Tensor] = None,
5927
+ ) -> torch.Tensor:
5928
+ original_dtype = hidden_states.dtype
5929
+
5930
+ hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
5931
+ hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
5932
+
5933
+ query = attn.to_q(hidden_states_org)
5934
+ key = attn.to_k(hidden_states_org)
5935
+ value = attn.to_v(hidden_states_org)
5936
+
5937
+ query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
5938
+ key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
5939
+ value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
5940
+
5941
+ query = F.relu(query)
5942
+ key = F.relu(key)
5943
+
5944
+ query, key, value = query.float(), key.float(), value.float()
5945
+
5946
+ value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
5947
+ scores = torch.matmul(value, key)
5948
+ hidden_states_org = torch.matmul(scores, query)
5949
+
5950
+ hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + 1e-15)
5951
+ hidden_states_org = hidden_states_org.flatten(1, 2).transpose(1, 2)
5952
+ hidden_states_org = hidden_states_org.to(original_dtype)
5953
+
5954
+ hidden_states_org = attn.to_out[0](hidden_states_org)
5955
+ hidden_states_org = attn.to_out[1](hidden_states_org)
5956
+
5957
+ # perturbed path (identity attention)
5958
+ hidden_states_ptb = attn.to_v(hidden_states_ptb).to(original_dtype)
5959
+
5960
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
5961
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
5962
+
5963
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
5964
+
5965
+ if original_dtype == torch.float16:
5966
+ hidden_states = hidden_states.clip(-65504, 65504)
5967
+
5968
+ return hidden_states
5969
+
5970
+
5971
+ class PAGIdentitySanaLinearAttnProcessor2_0:
5972
+ r"""
5973
+ Processor for implementing scaled dot-product linear attention.
5974
+ """
5975
+
5976
+ def __call__(
5977
+ self,
5978
+ attn: Attention,
5979
+ hidden_states: torch.Tensor,
5980
+ encoder_hidden_states: Optional[torch.Tensor] = None,
5981
+ attention_mask: Optional[torch.Tensor] = None,
5982
+ ) -> torch.Tensor:
5983
+ original_dtype = hidden_states.dtype
5984
+
5985
+ hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
5986
+
5987
+ query = attn.to_q(hidden_states_org)
5988
+ key = attn.to_k(hidden_states_org)
5989
+ value = attn.to_v(hidden_states_org)
5990
+
5991
+ query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
5992
+ key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
5993
+ value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
5994
+
5995
+ query = F.relu(query)
5996
+ key = F.relu(key)
5997
+
5998
+ query, key, value = query.float(), key.float(), value.float()
5999
+
6000
+ value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)
6001
+ scores = torch.matmul(value, key)
6002
+ hidden_states_org = torch.matmul(scores, query)
6003
+
6004
+ if hidden_states_org.dtype in [torch.float16, torch.bfloat16]:
6005
+ hidden_states_org = hidden_states_org.float()
6006
+
6007
+ hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + 1e-15)
6008
+ hidden_states_org = hidden_states_org.flatten(1, 2).transpose(1, 2)
6009
+ hidden_states_org = hidden_states_org.to(original_dtype)
6010
+
6011
+ hidden_states_org = attn.to_out[0](hidden_states_org)
6012
+ hidden_states_org = attn.to_out[1](hidden_states_org)
6013
+
6014
+ # perturbed path (identity attention)
6015
+ hidden_states_ptb = attn.to_v(hidden_states_ptb).to(original_dtype)
6016
+
6017
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
6018
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
6019
+
6020
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
6021
+
6022
+ if original_dtype == torch.float16:
6023
+ hidden_states = hidden_states.clip(-65504, 65504)
6024
+
6025
+ return hidden_states
6026
+
6027
+
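Both PAG variants above rely on a fixed batch layout rather than any extra arguments; an illustrative sketch of how the inputs are expected to be stacked along the batch dimension (tensors are made up):

    import torch

    latents = torch.randn(1, 4, 32, 32).flatten(2).transpose(1, 2)  # (1, tokens, channels), hypothetical

    # PAGIdentitySanaLinearAttnProcessor2_0: chunk(2) -> original, perturbed
    hidden_states_pag = torch.cat([latents, latents], dim=0)

    # PAGCFGSanaLinearAttnProcessor2_0: chunk(3) -> unconditional, original, perturbed
    hidden_states_pag_cfg = torch.cat([latents, latents, latents], dim=0)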
4250
6028
  ADDED_KV_ATTENTION_PROCESSORS = (
4251
6029
  AttnAddedKVProcessor,
4252
6030
  SlicedAttnAddedKVProcessor,
@@ -4261,23 +6039,59 @@ CROSS_ATTENTION_PROCESSORS = (
4261
6039
  SlicedAttnProcessor,
4262
6040
  IPAdapterAttnProcessor,
4263
6041
  IPAdapterAttnProcessor2_0,
6042
+ FluxIPAdapterJointAttnProcessor2_0,
4264
6043
  )
4265
6044
 
4266
6045
  AttentionProcessor = Union[
4267
6046
  AttnProcessor,
4268
- AttnProcessor2_0,
4269
- FusedAttnProcessor2_0,
4270
- XFormersAttnProcessor,
4271
- SlicedAttnProcessor,
6047
+ CustomDiffusionAttnProcessor,
4272
6048
  AttnAddedKVProcessor,
4273
- SlicedAttnAddedKVProcessor,
4274
6049
  AttnAddedKVProcessor2_0,
6050
+ JointAttnProcessor2_0,
6051
+ PAGJointAttnProcessor2_0,
6052
+ PAGCFGJointAttnProcessor2_0,
6053
+ FusedJointAttnProcessor2_0,
6054
+ AllegroAttnProcessor2_0,
6055
+ AuraFlowAttnProcessor2_0,
6056
+ FusedAuraFlowAttnProcessor2_0,
6057
+ FluxAttnProcessor2_0,
6058
+ FluxAttnProcessor2_0_NPU,
6059
+ FusedFluxAttnProcessor2_0,
6060
+ FusedFluxAttnProcessor2_0_NPU,
6061
+ CogVideoXAttnProcessor2_0,
6062
+ FusedCogVideoXAttnProcessor2_0,
4275
6063
  XFormersAttnAddedKVProcessor,
4276
- CustomDiffusionAttnProcessor,
6064
+ XFormersAttnProcessor,
6065
+ XLAFlashAttnProcessor2_0,
6066
+ AttnProcessorNPU,
6067
+ AttnProcessor2_0,
6068
+ MochiVaeAttnProcessor2_0,
6069
+ MochiAttnProcessor2_0,
6070
+ StableAudioAttnProcessor2_0,
6071
+ HunyuanAttnProcessor2_0,
6072
+ FusedHunyuanAttnProcessor2_0,
6073
+ PAGHunyuanAttnProcessor2_0,
6074
+ PAGCFGHunyuanAttnProcessor2_0,
6075
+ LuminaAttnProcessor2_0,
6076
+ FusedAttnProcessor2_0,
4277
6077
  CustomDiffusionXFormersAttnProcessor,
4278
6078
  CustomDiffusionAttnProcessor2_0,
4279
- PAGCFGIdentitySelfAttnProcessor2_0,
6079
+ SlicedAttnProcessor,
6080
+ SlicedAttnAddedKVProcessor,
6081
+ SanaLinearAttnProcessor2_0,
6082
+ PAGCFGSanaLinearAttnProcessor2_0,
6083
+ PAGIdentitySanaLinearAttnProcessor2_0,
6084
+ SanaMultiscaleLinearAttention,
6085
+ SanaMultiscaleAttnProcessor2_0,
6086
+ SanaMultiscaleAttentionProjection,
6087
+ IPAdapterAttnProcessor,
6088
+ IPAdapterAttnProcessor2_0,
6089
+ IPAdapterXFormersAttnProcessor,
6090
+ SD3IPAdapterJointAttnProcessor2_0,
4280
6091
  PAGIdentitySelfAttnProcessor2_0,
4281
- PAGCFGHunyuanAttnProcessor2_0,
4282
- PAGHunyuanAttnProcessor2_0,
6092
+ PAGCFGIdentitySelfAttnProcessor2_0,
6093
+ LoRAAttnProcessor,
6094
+ LoRAAttnProcessor2_0,
6095
+ LoRAXFormersAttnProcessor,
6096
+ LoRAAttnAddedKVProcessor,
4283
6097
  ]
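The `AttentionProcessor` union above is what model-level helpers such as `attn_processors` and `set_attn_processor` are typed against; a short inspection sketch (the checkpoint name is an assumption):

    import torch
    from diffusers import UNet2DConditionModel
    from diffusers.models.attention_processor import AttnProcessor2_0

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
    )
    print({type(p).__name__ for p in unet.attn_processors.values()})
    unet.set_attn_processor(AttnProcessor2_0())  # replace every processor with plain SDPA attention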