diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +19 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +20 -26
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +42 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
  229. diffusers/schedulers/scheduling_edm_euler.py +50 -31
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
  231. diffusers/schedulers/scheduling_euler_discrete.py +160 -68
  232. diffusers/schedulers/scheduling_heun_discrete.py +57 -39
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +24 -26
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
  267. diffusers-0.27.2.dist-info/RECORD +0 -399
  268. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  269. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/models/attention_processor.py
@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+import math
 from importlib import import_module
-from typing import Callable, Optional, Union
+from typing import Callable, List, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -21,13 +22,15 @@ from torch import nn
 
 from ..image_processor import IPAdapterMaskProcessor
 from ..utils import deprecate, logging
-from ..utils.import_utils import is_xformers_available
+from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import maybe_allow_in_graph
 from .lora import LoRALinearLayer
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+if is_torch_npu_available():
+    import torch_npu
 
 if is_xformers_available():
     import xformers
@@ -181,25 +184,22 @@ class Attention(nn.Module):
                 f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
             )
 
-        linear_cls = nn.Linear
-
-        self.linear_cls = linear_cls
-        self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
+        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
 
         if not self.only_cross_attention:
             # only relevant for the `AddedKVProcessor` classes
-            self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
-            self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
+            self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
+            self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
         else:
             self.to_k = None
             self.to_v = None
 
         if self.added_kv_proj_dim is not None:
-            self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
-            self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
+            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
+            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
 
         self.to_out = nn.ModuleList([])
-        self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
         self.to_out.append(nn.Dropout(dropout))
 
         # set attention processor
@@ -212,6 +212,23 @@
             )
         self.set_processor(processor)
 
+    def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
+        r"""
+        Set whether to use npu flash attention from `torch_npu` or not.
+
+        """
+        if use_npu_flash_attention:
+            processor = AttnProcessorNPU()
+        else:
+            # set attention processor
+            # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+            # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+            # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+            processor = (
+                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+            )
+        self.set_processor(processor)
+
     def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ) -> None:
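
Note: the new `set_use_npu_flash_attention` toggle mirrors `set_use_memory_efficient_attention_xformers`: it swaps the module's processor up front instead of branching at call time. A minimal sketch of flipping it across a whole model (the helper function below is illustrative, not part of this diff):

    from diffusers.models.attention_processor import Attention

    def toggle_npu_flash_attention(model, enabled: bool = True):
        # Walk every submodule and flip the processor on each Attention block.
        # set_use_npu_flash_attention installs AttnProcessorNPU, or restores the
        # default AttnProcessor2_0/AttnProcessor, as shown in the hunk above.
        for module in model.modules():
            if isinstance(module, Attention):
                module.set_use_npu_flash_attention(enabled)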
@@ -424,7 +441,7 @@
         # If doesn't apply LoRA do `add_k_proj` or `add_v_proj`
         is_lora_activated.pop("add_k_proj", None)
         is_lora_activated.pop("add_v_proj", None)
-        # 2. else it is not posssible that only some layers have LoRA activated
+        # 2. else it is not possible that only some layers have LoRA activated
         if not all(is_lora_activated.values()):
             raise ValueError(
                 f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
@@ -486,9 +503,9 @@
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         **cross_attention_kwargs,
     ) -> torch.Tensor:
         r"""
@@ -706,7 +723,7 @@
             out_features = concatenated_weights.shape[0]
 
             # create a new single projection layer and copy over the weights.
-            self.to_qkv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
             self.to_qkv.weight.copy_(concatenated_weights)
             if self.use_bias:
                 concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
@@ -717,7 +734,7 @@
             in_features = concatenated_weights.shape[1]
             out_features = concatenated_weights.shape[0]
 
-            self.to_kv = self.linear_cls(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
             self.to_kv.weight.copy_(concatenated_weights)
             if self.use_bias:
                 concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
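
Note: both fusion branches now construct a plain `nn.Linear` rather than going through the removed `self.linear_cls` indirection. On the pipeline side this code path is reached via `fuse_qkv_projections()`; a hedged sketch (model id and device are placeholders):

    import torch
    from diffusers import StableDiffusionXLPipeline

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")
    # Concatenates to_q/to_k/to_v weights into the to_qkv layer built above
    # (to_k/to_v into to_kv for cross-attention) and switches the attention
    # processors to the fused variant.
    pipe.fuse_qkv_projections()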
@@ -734,10 +751,10 @@ class AttnProcessor:
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
         *args,
         **kwargs,
     ) -> torch.Tensor:
846
863
  def __call__(
847
864
  self,
848
865
  attn: Attention,
849
- hidden_states: torch.FloatTensor,
850
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
851
- attention_mask: Optional[torch.FloatTensor] = None,
866
+ hidden_states: torch.Tensor,
867
+ encoder_hidden_states: Optional[torch.Tensor] = None,
868
+ attention_mask: Optional[torch.Tensor] = None,
852
869
  ) -> torch.Tensor:
853
870
  batch_size, sequence_length, _ = hidden_states.shape
854
871
  attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
@@ -911,9 +928,9 @@ class AttnAddedKVProcessor:
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         *args,
         **kwargs,
     ) -> torch.Tensor:
@@ -984,9 +1001,9 @@ class AttnAddedKVProcessor2_0:
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         *args,
         **kwargs,
     ) -> torch.Tensor:
@@ -1063,9 +1080,9 @@ class XFormersAttnAddedKVProcessor:
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
@@ -1134,13 +1151,13 @@ class XFormersAttnProcessor:
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1210,6 +1227,116 @@ class XFormersAttnProcessor:
         return hidden_states
 
 
+class AttnProcessorNPU:
+
+    r"""
+    Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
+    fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
+    not significant.
+
+    """
+
+    def __init__(self):
+        if not is_torch_npu_available():
+            raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        if query.dtype in (torch.float16, torch.bfloat16):
+            hidden_states = torch_npu.npu_fusion_attention(
+                query,
+                key,
+                value,
+                attn.heads,
+                input_layout="BNSD",
+                pse=None,
+                atten_mask=attention_mask,
+                scale=1.0 / math.sqrt(query.shape[-1]),
+                pre_tockens=65536,
+                next_tockens=65536,
+                keep_prob=1.0,
+                sync=False,
+                inner_precise=0,
+            )[0]
+        else:
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class AttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
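
Note: a hedged usage sketch for the new processor. It assumes an Ascend environment where `torch_npu` is installed and "npu" is a valid device string; the model id is a placeholder:

    import torch
    from diffusers import UNet2DConditionModel
    from diffusers.models.attention_processor import AttnProcessorNPU

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
    ).to("npu")
    # fp16/bf16 activations take the torch_npu.npu_fusion_attention path above;
    # fp32 silently falls back to F.scaled_dot_product_attention.
    unet.set_attn_processor(AttnProcessorNPU())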
@@ -1222,13 +1349,13 @@ class AttnProcessor2_0:
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1301,9 +1428,9 @@ class AttnProcessor2_0:
 
 class FusedAttnProcessor2_0:
     r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query,
-    key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
+    fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
+    For cross-attention modules, key and value projection matrices are fused.
 
     <Tip warning={true}>
 
@@ -1321,13 +1448,13 @@ class FusedAttnProcessor2_0:
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1454,10 +1581,10 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module):
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         batch_size, sequence_length, _ = (
             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         )
@@ -1565,10 +1692,10 @@ class CustomDiffusionAttnProcessor2_0(nn.Module):
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         batch_size, sequence_length, _ = hidden_states.shape
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         if self.train_q_out:
@@ -1646,10 +1773,10 @@ class SlicedAttnProcessor:
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         residual = hidden_states
 
         input_ndim = hidden_states.ndim
@@ -1733,11 +1860,11 @@ class SlicedAttnAddedKVProcessor:
     def __call__(
         self,
         attn: "Attention",
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        temb: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         residual = hidden_states
 
         if attn.spatial_norm is not None:
@@ -1830,7 +1957,7 @@ class SpatialNorm(nn.Module):
         self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
         self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
 
-    def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
         f_size = f.shape[-2:]
         zq = F.interpolate(zq, size=f_size, mode="nearest")
         norm_f = self.norm_layer(f)
@@ -1876,7 +2003,7 @@ class LoRAAttnProcessor(nn.Module):
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
         self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
 
-    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+    def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
         self_cls_name = self.__class__.__name__
         deprecate(
             self_cls_name,
@@ -1937,7 +2064,7 @@ class LoRAAttnProcessor2_0(nn.Module):
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
         self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
 
-    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+    def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
         self_cls_name = self.__class__.__name__
         deprecate(
             self_cls_name,
@@ -2016,7 +2143,7 @@ class LoRAXFormersAttnProcessor(nn.Module):
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
         self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
 
-    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+    def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
         self_cls_name = self.__class__.__name__
         deprecate(
             self_cls_name,
@@ -2075,7 +2202,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
         self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
         self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
 
-    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+    def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
         self_cls_name = self.__class__.__name__
         deprecate(
             self_cls_name,
@@ -2098,7 +2225,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
 
 class IPAdapterAttnProcessor(nn.Module):
     r"""
-    Attention processor for Multiple IP-Adapater.
+    Attention processor for Multiple IP-Adapters.
 
     Args:
         hidden_size (`int`):
@@ -2137,12 +2264,12 @@ class IPAdapterAttnProcessor(nn.Module):
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
         scale: float = 1.0,
-        ip_adapter_masks: Optional[torch.FloatTensor] = None,
+        ip_adapter_masks: Optional[torch.Tensor] = None,
     ):
         residual = hidden_states
 
@@ -2152,8 +2279,8 @@ class IPAdapterAttnProcessor(nn.Module):
                 encoder_hidden_states, ip_hidden_states = encoder_hidden_states
             else:
                 deprecation_message = (
-                    "You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
-                    " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
+                    "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
+                    " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
                 )
                 deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
                 end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
@@ -2198,15 +2325,33 @@ class IPAdapterAttnProcessor(nn.Module):
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
         if ip_adapter_masks is not None:
-            if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
+            if not isinstance(ip_adapter_masks, List):
+                # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+                ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+            if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
                 raise ValueError(
-                    " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
-                    " Please use `IPAdapterMaskProcessor` to preprocess your mask"
-                )
-            if len(ip_adapter_masks) != len(self.scale):
-                raise ValueError(
-                    f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
+                    f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+                    f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+                    f"({len(ip_hidden_states)})"
                 )
+            else:
+                for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+                    if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+                        raise ValueError(
+                            "Each element of the ip_adapter_masks array should be a tensor with shape "
+                            "[1, num_images_for_ip_adapter, height, width]."
+                            " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+                        )
+                    if mask.shape[1] != ip_state.shape[1]:
+                        raise ValueError(
+                            f"Number of masks ({mask.shape[1]}) does not match "
+                            f"number of ip images ({ip_state.shape[1]}) at index {index}"
+                        )
+                    if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+                        raise ValueError(
+                            f"Number of masks ({mask.shape[1]}) does not match "
+                            f"number of scales ({len(scale)}) at index {index}"
+                        )
         else:
             ip_adapter_masks = [None] * len(self.scale)
 
@@ -2214,26 +2359,51 @@ class IPAdapterAttnProcessor(nn.Module):
         for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
             ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
         ):
-            ip_key = to_k_ip(current_ip_hidden_states)
-            ip_value = to_v_ip(current_ip_hidden_states)
-
-            ip_key = attn.head_to_batch_dim(ip_key)
-            ip_value = attn.head_to_batch_dim(ip_value)
-
-            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
-            current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
-            current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
-
-            if mask is not None:
-                mask_downsample = IPAdapterMaskProcessor.downsample(
-                    mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
-                )
-
-                mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
-
-                current_ip_hidden_states = current_ip_hidden_states * mask_downsample
-
-            hidden_states = hidden_states + scale * current_ip_hidden_states
+            skip = False
+            if isinstance(scale, list):
+                if all(s == 0 for s in scale):
+                    skip = True
+            elif scale == 0:
+                skip = True
+            if not skip:
+                if mask is not None:
+                    if not isinstance(scale, list):
+                        scale = [scale] * mask.shape[1]
+
+                    current_num_images = mask.shape[1]
+                    for i in range(current_num_images):
+                        ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+                        ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+                        ip_key = attn.head_to_batch_dim(ip_key)
+                        ip_value = attn.head_to_batch_dim(ip_value)
+
+                        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+                        _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+                        _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
+
+                        mask_downsample = IPAdapterMaskProcessor.downsample(
+                            mask[:, i, :, :],
+                            batch_size,
+                            _current_ip_hidden_states.shape[1],
+                            _current_ip_hidden_states.shape[2],
+                        )
+
+                        mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+
+                        hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+                else:
+                    ip_key = to_k_ip(current_ip_hidden_states)
+                    ip_value = to_v_ip(current_ip_hidden_states)
+
+                    ip_key = attn.head_to_batch_dim(ip_key)
+                    ip_value = attn.head_to_batch_dim(ip_value)
+
+                    ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+                    current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+                    current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
+
+                    hidden_states = hidden_states + scale * current_ip_hidden_states
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
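
Note: a hedged sketch of the mask format the rewritten loop expects: one list entry per IP-Adapter, each shaped [1, num_images_for_that_adapter, height, width]. `pipe`, `mask_left`/`mask_right`, and `face_left`/`face_right` are assumed for illustration:

    from diffusers.image_processor import IPAdapterMaskProcessor

    processor = IPAdapterMaskProcessor()
    masks = processor.preprocess([mask_left, mask_right], height=1024, width=1024)
    # One adapter fed two reference images -> one list entry of shape [1, 2, H, W].
    ip_adapter_masks = [masks.reshape(1, masks.shape[0], masks.shape[2], masks.shape[3])]

    image = pipe(
        prompt="two friends in a park",
        ip_adapter_image=[[face_left, face_right]],
        cross_attention_kwargs={"ip_adapter_masks": ip_adapter_masks},
    ).images[0]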
@@ -2253,7 +2423,7 @@ class IPAdapterAttnProcessor(nn.Module):
 
 class IPAdapterAttnProcessor2_0(torch.nn.Module):
     r"""
-    Attention processor for IP-Adapater for PyTorch 2.0.
+    Attention processor for IP-Adapter for PyTorch 2.0.
 
     Args:
         hidden_size (`int`):
@@ -2297,12 +2467,12 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
         scale: float = 1.0,
-        ip_adapter_masks: Optional[torch.FloatTensor] = None,
+        ip_adapter_masks: Optional[torch.Tensor] = None,
     ):
         residual = hidden_states
 
@@ -2312,8 +2482,8 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
                 encoder_hidden_states, ip_hidden_states = encoder_hidden_states
             else:
                 deprecation_message = (
-                    "You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
-                    " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
+                    "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
+                    " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
                 )
                 deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
                 end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
@@ -2372,15 +2542,33 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
         hidden_states = hidden_states.to(query.dtype)
 
         if ip_adapter_masks is not None:
-            if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
-                raise ValueError(
-                    " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
-                    " Please use `IPAdapterMaskProcessor` to preprocess your mask"
-                )
-            if len(ip_adapter_masks) != len(self.scale):
+            if not isinstance(ip_adapter_masks, List):
+                # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+                ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+            if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
                 raise ValueError(
-                    f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
+                    f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+                    f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+                    f"({len(ip_hidden_states)})"
                 )
+            else:
+                for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+                    if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+                        raise ValueError(
+                            "Each element of the ip_adapter_masks array should be a tensor with shape "
+                            "[1, num_images_for_ip_adapter, height, width]."
+                            " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+                        )
+                    if mask.shape[1] != ip_state.shape[1]:
+                        raise ValueError(
+                            f"Number of masks ({mask.shape[1]}) does not match "
+                            f"number of ip images ({ip_state.shape[1]}) at index {index}"
+                        )
+                    if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+                        raise ValueError(
+                            f"Number of masks ({mask.shape[1]}) does not match "
+                            f"number of scales ({len(scale)}) at index {index}"
+                        )
         else:
             ip_adapter_masks = [None] * len(self.scale)
 
@@ -2388,33 +2576,64 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
         for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
             ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
         ):
-            ip_key = to_k_ip(current_ip_hidden_states)
-            ip_value = to_v_ip(current_ip_hidden_states)
-
-            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-            # the output of sdp = (batch, num_heads, seq_len, head_dim)
-            # TODO: add support for attn.scale when we move to Torch 2.1
-            current_ip_hidden_states = F.scaled_dot_product_attention(
-                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-            )
-
-            current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
-                batch_size, -1, attn.heads * head_dim
-            )
-            current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
-
-            if mask is not None:
-                mask_downsample = IPAdapterMaskProcessor.downsample(
-                    mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
-                )
-
-                mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+            skip = False
+            if isinstance(scale, list):
+                if all(s == 0 for s in scale):
+                    skip = True
+            elif scale == 0:
+                skip = True
+            if not skip:
+                if mask is not None:
+                    if not isinstance(scale, list):
+                        scale = [scale] * mask.shape[1]
+
+                    current_num_images = mask.shape[1]
+                    for i in range(current_num_images):
+                        ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+                        ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+                        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                        # TODO: add support for attn.scale when we move to Torch 2.1
+                        _current_ip_hidden_states = F.scaled_dot_product_attention(
+                            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                        )
+
+                        _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
+                            batch_size, -1, attn.heads * head_dim
+                        )
+                        _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
+
+                        mask_downsample = IPAdapterMaskProcessor.downsample(
+                            mask[:, i, :, :],
+                            batch_size,
+                            _current_ip_hidden_states.shape[1],
+                            _current_ip_hidden_states.shape[2],
+                        )
+
+                        mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+                        hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+                else:
+                    ip_key = to_k_ip(current_ip_hidden_states)
+                    ip_value = to_v_ip(current_ip_hidden_states)
+
+                    ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                    ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                    # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                    # TODO: add support for attn.scale when we move to Torch 2.1
+                    current_ip_hidden_states = F.scaled_dot_product_attention(
+                        query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                    )
 
-                current_ip_hidden_states = current_ip_hidden_states * mask_downsample
+                    current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+                        batch_size, -1, attn.heads * head_dim
+                    )
+                    current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
 
-            hidden_states = hidden_states + scale * current_ip_hidden_states
+                    hidden_states = hidden_states + scale * current_ip_hidden_states
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
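
Note: the `skip` guard above means a scale of 0 now short-circuits an adapter entirely rather than computing attention and multiplying by zero. A hedged sketch, assuming `pipe` has two IP-Adapters loaded via `load_ip_adapter`:

    # The second adapter contributes nothing and costs no attention compute
    # for this call; it stays loaded and can be re-enabled later.
    pipe.set_ip_adapter_scale([0.6, 0.0])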