diffusers 0.34.0__py3-none-any.whl → 0.35.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (191)
  1. diffusers/__init__.py +98 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/custom_blocks.py +134 -0
  4. diffusers/commands/diffusers_cli.py +2 -0
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +11 -2
  7. diffusers/dependency_versions_table.py +3 -3
  8. diffusers/guiders/__init__.py +41 -0
  9. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  10. diffusers/guiders/auto_guidance.py +190 -0
  11. diffusers/guiders/classifier_free_guidance.py +141 -0
  12. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  13. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  14. diffusers/guiders/guider_utils.py +309 -0
  15. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  16. diffusers/guiders/skip_layer_guidance.py +262 -0
  17. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  18. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  19. diffusers/hooks/__init__.py +17 -0
  20. diffusers/hooks/_common.py +56 -0
  21. diffusers/hooks/_helpers.py +293 -0
  22. diffusers/hooks/faster_cache.py +7 -6
  23. diffusers/hooks/first_block_cache.py +259 -0
  24. diffusers/hooks/group_offloading.py +292 -286
  25. diffusers/hooks/hooks.py +56 -1
  26. diffusers/hooks/layer_skip.py +263 -0
  27. diffusers/hooks/layerwise_casting.py +2 -7
  28. diffusers/hooks/pyramid_attention_broadcast.py +14 -11
  29. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  30. diffusers/hooks/utils.py +43 -0
  31. diffusers/loaders/__init__.py +6 -0
  32. diffusers/loaders/ip_adapter.py +255 -4
  33. diffusers/loaders/lora_base.py +63 -30
  34. diffusers/loaders/lora_conversion_utils.py +434 -53
  35. diffusers/loaders/lora_pipeline.py +834 -37
  36. diffusers/loaders/peft.py +28 -5
  37. diffusers/loaders/single_file_model.py +44 -11
  38. diffusers/loaders/single_file_utils.py +170 -2
  39. diffusers/loaders/transformer_flux.py +9 -10
  40. diffusers/loaders/transformer_sd3.py +6 -1
  41. diffusers/loaders/unet.py +22 -5
  42. diffusers/loaders/unet_loader_utils.py +5 -2
  43. diffusers/models/__init__.py +8 -0
  44. diffusers/models/attention.py +484 -3
  45. diffusers/models/attention_dispatch.py +1218 -0
  46. diffusers/models/attention_processor.py +105 -663
  47. diffusers/models/auto_model.py +2 -2
  48. diffusers/models/autoencoders/__init__.py +1 -0
  49. diffusers/models/autoencoders/autoencoder_dc.py +14 -1
  50. diffusers/models/autoencoders/autoencoder_kl.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
  52. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  53. diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
  54. diffusers/models/cache_utils.py +31 -9
  55. diffusers/models/controlnets/controlnet_flux.py +5 -5
  56. diffusers/models/controlnets/controlnet_union.py +4 -4
  57. diffusers/models/embeddings.py +26 -34
  58. diffusers/models/model_loading_utils.py +233 -1
  59. diffusers/models/modeling_flax_utils.py +1 -2
  60. diffusers/models/modeling_utils.py +159 -94
  61. diffusers/models/transformers/__init__.py +2 -0
  62. diffusers/models/transformers/transformer_chroma.py +16 -117
  63. diffusers/models/transformers/transformer_cogview4.py +36 -2
  64. diffusers/models/transformers/transformer_cosmos.py +11 -4
  65. diffusers/models/transformers/transformer_flux.py +372 -132
  66. diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
  67. diffusers/models/transformers/transformer_ltx.py +104 -23
  68. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  69. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  70. diffusers/models/transformers/transformer_wan.py +298 -85
  71. diffusers/models/transformers/transformer_wan_vace.py +15 -21
  72. diffusers/models/unets/unet_2d_condition.py +2 -1
  73. diffusers/modular_pipelines/__init__.py +83 -0
  74. diffusers/modular_pipelines/components_manager.py +1068 -0
  75. diffusers/modular_pipelines/flux/__init__.py +66 -0
  76. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  77. diffusers/modular_pipelines/flux/decoders.py +109 -0
  78. diffusers/modular_pipelines/flux/denoise.py +227 -0
  79. diffusers/modular_pipelines/flux/encoders.py +412 -0
  80. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  81. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  82. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  83. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  84. diffusers/modular_pipelines/node_utils.py +665 -0
  85. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  86. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  87. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  88. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  89. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  90. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  91. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  92. diffusers/modular_pipelines/wan/__init__.py +66 -0
  93. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  94. diffusers/modular_pipelines/wan/decoders.py +105 -0
  95. diffusers/modular_pipelines/wan/denoise.py +261 -0
  96. diffusers/modular_pipelines/wan/encoders.py +242 -0
  97. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  98. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  99. diffusers/pipelines/__init__.py +31 -0
  100. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
  101. diffusers/pipelines/auto_pipeline.py +17 -13
  102. diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
  103. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
  104. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
  105. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
  106. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
  107. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
  108. diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
  109. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
  110. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
  111. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
  113. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
  114. diffusers/pipelines/dit/pipeline_dit.py +3 -1
  115. diffusers/pipelines/flux/__init__.py +4 -0
  116. diffusers/pipelines/flux/pipeline_flux.py +34 -26
  117. diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
  118. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
  119. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
  120. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
  121. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
  122. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
  123. diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
  124. diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
  125. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
  126. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  127. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  128. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  129. diffusers/pipelines/flux/pipeline_output.py +6 -4
  130. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
  131. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
  132. diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
  133. diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
  134. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
  135. diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
  136. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  137. diffusers/pipelines/pipeline_loading_utils.py +24 -2
  138. diffusers/pipelines/pipeline_utils.py +22 -15
  139. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
  140. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
  141. diffusers/pipelines/qwenimage/__init__.py +55 -0
  142. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  143. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  144. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +849 -0
  145. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  146. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  147. diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
  148. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  149. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  150. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  151. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  152. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  153. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  154. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  155. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
  156. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  157. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  158. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
  160. diffusers/pipelines/wan/pipeline_wan.py +78 -20
  161. diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
  162. diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
  163. diffusers/quantizers/__init__.py +1 -177
  164. diffusers/quantizers/base.py +11 -0
  165. diffusers/quantizers/gguf/utils.py +92 -3
  166. diffusers/quantizers/pipe_quant_config.py +202 -0
  167. diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
  168. diffusers/schedulers/scheduling_deis_multistep.py +8 -1
  169. diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
  170. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
  171. diffusers/schedulers/scheduling_scm.py +0 -1
  172. diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
  173. diffusers/schedulers/scheduling_utils.py +2 -2
  174. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  175. diffusers/training_utils.py +78 -0
  176. diffusers/utils/__init__.py +10 -0
  177. diffusers/utils/constants.py +4 -0
  178. diffusers/utils/dummy_pt_objects.py +312 -0
  179. diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
  180. diffusers/utils/dynamic_modules_utils.py +84 -25
  181. diffusers/utils/hub_utils.py +33 -17
  182. diffusers/utils/import_utils.py +70 -0
  183. diffusers/utils/peft_utils.py +11 -8
  184. diffusers/utils/testing_utils.py +136 -10
  185. diffusers/utils/torch_utils.py +18 -0
  186. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/METADATA +6 -6
  187. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/RECORD +191 -127
  188. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/LICENSE +0 -0
  189. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/WHEEL +0 -0
  190. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/entry_points.txt +0 -0
  191. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/top_level.txt +0 -0
diffusers/loaders/unet_loader_utils.py
@@ -14,6 +14,8 @@
 import copy
 from typing import TYPE_CHECKING, Dict, List, Union
 
+from torch import nn
+
 from ..utils import logging
 
 
@@ -52,7 +54,7 @@ def _maybe_expand_lora_scales(
             weight_for_adapter,
             blocks_with_transformer,
             transformer_per_block,
-            unet.state_dict(),
+            model=unet,
             default_scale=default_scale,
         )
         for weight_for_adapter in weight_scales
@@ -65,7 +67,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
     scales: Union[float, Dict],
     blocks_with_transformer: Dict[str, int],
     transformer_per_block: Dict[str, int],
-    state_dict: None,
+    model: nn.Module,
     default_scale: float = 1.0,
 ):
     """
@@ -154,6 +156,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
 
         del scales[updown]
 
+    state_dict = model.state_dict()
    for layer in scales.keys():
        if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
            raise ValueError(
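
Note: the hunks above change the private LoRA-scale helper so it receives the UNet module itself (`model=unet`) instead of a pre-computed `unet.state_dict()`; the state dict is now pulled inside `_maybe_expand_lora_scales_for_one_adapter`. A rough sketch of the new calling convention, with illustrative scale values and a block layout built the same way `_maybe_expand_lora_scales` builds it:

# Hedged sketch of the 0.35.x signature shown above; the scale values are illustrative
# and the helper is private, so this mirrors internal usage rather than a public API.
from diffusers import UNet2DConditionModel
from diffusers.loaders.unet_loader_utils import _maybe_expand_lora_scales_for_one_adapter

unet = UNet2DConditionModel()  # default config; no pretrained weights needed for this illustration
blocks_with_transformer = {
    "down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
    "up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
}
transformer_per_block = {"down": unet.config.layers_per_block, "up": unet.config.layers_per_block + 1}

expanded = _maybe_expand_lora_scales_for_one_adapter(
    {"down": 0.8, "mid": 1.0, "up": 1.2},  # illustrative per-region LoRA scales
    blocks_with_transformer,
    transformer_per_block,
    model=unet,  # 0.34.0 passed unet.state_dict() positionally here
    default_scale=1.0,
)
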
diffusers/models/__init__.py
@@ -26,6 +26,7 @@ _import_structure = {}
 
 if is_torch_available():
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
+    _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
     _import_structure["auto_model"] = ["AutoModel"]
     _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
     _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
@@ -37,6 +38,7 @@ if is_torch_available():
     _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
     _import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
     _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
+    _import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
     _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
     _import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
     _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
@@ -87,7 +89,9 @@ if is_torch_available():
     _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
     _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
     _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
+    _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
     _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
+    _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
     _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
     _import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
     _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
@@ -111,6 +115,7 @@ if is_flax_available():
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
         from .adapter import MultiAdapter, T2IAdapter
+        from .attention_dispatch import AttentionBackendName, attention_backend
         from .auto_model import AutoModel
         from .autoencoders import (
            AsymmetricAutoencoderKL,
@@ -123,6 +128,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            AutoencoderKLLTXVideo,
            AutoencoderKLMagvit,
            AutoencoderKLMochi,
+           AutoencoderKLQwenImage,
            AutoencoderKLTemporalDecoder,
            AutoencoderKLWan,
            AutoencoderOobleck,
@@ -174,8 +180,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            OmniGenTransformer2DModel,
            PixArtTransformer2DModel,
            PriorTransformer,
+           QwenImageTransformer2DModel,
            SanaTransformer2DModel,
            SD3Transformer2DModel,
+           SkyReelsV2Transformer3DModel,
            StableAudioDiTModel,
            T5FilmDecoder,
            Transformer2DModel,
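
Note: the hunks above add five new public exports to `diffusers.models`: the attention-dispatch entry points (`AttentionBackendName`, `attention_backend`) plus the QwenImage and SkyReels-V2 model classes. A minimal import sketch; only the names come from the diff, the rest is illustrative:

# Hedged sketch: importing the names newly exported by diffusers.models in 0.35.x.
from diffusers.models import (
    AttentionBackendName,
    AutoencoderKLQwenImage,
    QwenImageTransformer2DModel,
    SkyReelsV2Transformer3DModel,
    attention_backend,
)

# AttentionBackendName is an enum; its members' values are the strings accepted by
# AttentionModuleMixin.set_attention_backend in the attention.py hunk further down.
print(sorted(x.value for x in AttentionBackendName.__members__.values()))
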
diffusers/models/attention.py
@@ -11,23 +11,504 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple
+
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
-from torch import nn
 
 from ..utils import deprecate, logging
+from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available
 from ..utils.torch_utils import maybe_allow_in_graph
 from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
-from .attention_processor import Attention, JointAttnProcessor2_0
+from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
 from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
 
 
+if is_xformers_available():
+    import xformers as xops
+else:
+    xops = None
+
+
 logger = logging.get_logger(__name__)
 
 
+class AttentionMixin:
+    @property
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+        """
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        for module in self.modules():
+            if isinstance(module, AttentionModuleMixin):
+                module.fuse_projections()
+
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        for module in self.modules():
+            if isinstance(module, AttentionModuleMixin):
+                module.unfuse_projections()
+
+
+class AttentionModuleMixin:
+    _default_processor_cls = None
+    _available_processors = []
+    fused_projections = False
+
+    def set_processor(self, processor: AttentionProcessor) -> None:
+        """
+        Set the attention processor to use.
+
+        Args:
+            processor (`AttnProcessor`):
+                The attention processor to use.
+        """
+        # if current processor is in `self._modules` and if passed `processor` is not, we need to
+        # pop `processor` from `self._modules`
+        if (
+            hasattr(self, "processor")
+            and isinstance(self.processor, torch.nn.Module)
+            and not isinstance(processor, torch.nn.Module)
+        ):
+            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+            self._modules.pop("processor")
+
+        self.processor = processor
+
+    def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
+        """
+        Get the attention processor in use.
+
+        Args:
+            return_deprecated_lora (`bool`, *optional*, defaults to `False`):
+                Set to `True` to return the deprecated LoRA attention processor.
+
+        Returns:
+            "AttentionProcessor": The attention processor in use.
+        """
+        if not return_deprecated_lora:
+            return self.processor
+
+    def set_attention_backend(self, backend: str):
+        from .attention_dispatch import AttentionBackendName
+
+        available_backends = {x.value for x in AttentionBackendName.__members__.values()}
+        if backend not in available_backends:
+            raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
+
+        backend = AttentionBackendName(backend.lower())
+        self.processor._attention_backend = backend
+
+    def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
+        """
+        Set whether to use NPU flash attention from `torch_npu` or not.
+
+        Args:
+            use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.
+        """
+
+        if use_npu_flash_attention:
+            if not is_torch_npu_available():
+                raise ImportError("torch_npu is not available")
+
+            self.set_attention_backend("_native_npu")
+
+    def set_use_xla_flash_attention(
+        self,
+        use_xla_flash_attention: bool,
+        partition_spec: Optional[Tuple[Optional[str], ...]] = None,
+        is_flux=False,
+    ) -> None:
+        """
+        Set whether to use XLA flash attention from `torch_xla` or not.
+
+        Args:
+            use_xla_flash_attention (`bool`):
+                Whether to use pallas flash attention kernel from `torch_xla` or not.
+            partition_spec (`Tuple[]`, *optional*):
+                Specify the partition specification if using SPMD. Otherwise None.
+            is_flux (`bool`, *optional*, defaults to `False`):
+                Whether the model is a Flux model.
+        """
+        if use_xla_flash_attention:
+            if not is_torch_xla_available():
+                raise ImportError("torch_xla is not available")
+
+            self.set_attention_backend("_native_xla")
+
+    def set_use_memory_efficient_attention_xformers(
+        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+    ) -> None:
+        """
+        Set whether to use memory efficient attention from `xformers` or not.
+
+        Args:
+            use_memory_efficient_attention_xformers (`bool`):
+                Whether to use memory efficient attention from `xformers` or not.
+            attention_op (`Callable`, *optional*):
+                The attention operation to use. Defaults to `None` which uses the default attention operation from
+                `xformers`.
+        """
+        if use_memory_efficient_attention_xformers:
+            if not is_xformers_available():
+                raise ModuleNotFoundError(
+                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers",
+                    name="xformers",
+                )
+            elif not torch.cuda.is_available():
+                raise ValueError(
+                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
+                    " only available for GPU "
+                )
+            else:
+                try:
+                    # Make sure we can run the memory efficient attention
+                    if is_xformers_available():
+                        dtype = None
+                        if attention_op is not None:
+                            op_fw, op_bw = attention_op
+                            dtype, *_ = op_fw.SUPPORTED_DTYPES
+                        q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
+                        _ = xops.memory_efficient_attention(q, q, q)
+                except Exception as e:
+                    raise e
+
+            self.set_attention_backend("xformers")
+
+    @torch.no_grad()
+    def fuse_projections(self):
+        """
+        Fuse the query, key, and value projections into a single projection for efficiency.
+        """
+        # Skip if already fused
+        if getattr(self, "fused_projections", False):
+            return
+
+        device = self.to_q.weight.data.device
+        dtype = self.to_q.weight.data.dtype
+
+        if hasattr(self, "is_cross_attention") and self.is_cross_attention:
+            # Fuse cross-attention key-value projections
+            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            self.to_kv.weight.copy_(concatenated_weights)
+            if hasattr(self, "use_bias") and self.use_bias:
+                concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+                self.to_kv.bias.copy_(concatenated_bias)
+        else:
+            # Fuse self-attention projections
+            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            self.to_qkv.weight.copy_(concatenated_weights)
+            if hasattr(self, "use_bias") and self.use_bias:
+                concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+                self.to_qkv.bias.copy_(concatenated_bias)
+
+        # Handle added projections for models like SD3, Flux, etc.
+        if (
+            getattr(self, "add_q_proj", None) is not None
+            and getattr(self, "add_k_proj", None) is not None
+            and getattr(self, "add_v_proj", None) is not None
+        ):
+            concatenated_weights = torch.cat(
+                [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
+            )
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_added_qkv = nn.Linear(
+                in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
+            )
+            self.to_added_qkv.weight.copy_(concatenated_weights)
+            if self.added_proj_bias:
+                concatenated_bias = torch.cat(
+                    [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
+                )
+                self.to_added_qkv.bias.copy_(concatenated_bias)
+
+        self.fused_projections = True
+
+    @torch.no_grad()
+    def unfuse_projections(self):
+        """
+        Unfuse the query, key, and value projections back to separate projections.
+        """
+        # Skip if not fused
+        if not getattr(self, "fused_projections", False):
+            return
+
+        # Remove fused projection layers
+        if hasattr(self, "to_qkv"):
+            delattr(self, "to_qkv")
+
+        if hasattr(self, "to_kv"):
+            delattr(self, "to_kv")
+
+        if hasattr(self, "to_added_qkv"):
+            delattr(self, "to_added_qkv")
+
+        self.fused_projections = False
+
+    def set_attention_slice(self, slice_size: int) -> None:
+        """
+        Set the slice size for attention computation.
+
+        Args:
+            slice_size (`int`):
+                The slice size for attention computation.
+        """
+        if hasattr(self, "sliceable_head_dim") and slice_size is not None and slice_size > self.sliceable_head_dim:
+            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+        processor = None
+
+        # Try to get a compatible processor for sliced attention
+        if slice_size is not None:
+            processor = self._get_compatible_processor("sliced")
+
+        # If no processor was found or slice_size is None, use default processor
+        if processor is None:
+            processor = self.default_processor_cls()
+
+        self.set_processor(processor)
+
+    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
+        """
+        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`.
+
+        Args:
+            tensor (`torch.Tensor`): The tensor to reshape.
+
+        Returns:
+            `torch.Tensor`: The reshaped tensor.
+        """
+        head_size = self.heads
+        batch_size, seq_len, dim = tensor.shape
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+        return tensor
+
+    def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
+        """
+        Reshape the tensor for multi-head attention processing.
+
+        Args:
+            tensor (`torch.Tensor`): The tensor to reshape.
+            out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor.
+
+        Returns:
+            `torch.Tensor`: The reshaped tensor.
+        """
+        head_size = self.heads
+        if tensor.ndim == 3:
+            batch_size, seq_len, dim = tensor.shape
+            extra_dim = 1
+        else:
+            batch_size, extra_dim, seq_len, dim = tensor.shape
+        tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3)
+
+        if out_dim == 3:
+            tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
+
+        return tensor
+
+    def get_attention_scores(
+        self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        """
+        Compute the attention scores.
+
+        Args:
+            query (`torch.Tensor`): The query tensor.
+            key (`torch.Tensor`): The key tensor.
+            attention_mask (`torch.Tensor`, *optional*): The attention mask to use.
+
+        Returns:
+            `torch.Tensor`: The attention probabilities/scores.
+        """
+        dtype = query.dtype
+        if self.upcast_attention:
+            query = query.float()
+            key = key.float()
+
+        if attention_mask is None:
+            baddbmm_input = torch.empty(
+                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+            )
+            beta = 0
+        else:
+            baddbmm_input = attention_mask
+            beta = 1
+
+        attention_scores = torch.baddbmm(
+            baddbmm_input,
+            query,
+            key.transpose(-1, -2),
+            beta=beta,
+            alpha=self.scale,
+        )
+        del baddbmm_input
+
+        if self.upcast_softmax:
+            attention_scores = attention_scores.float()
+
+        attention_probs = attention_scores.softmax(dim=-1)
+        del attention_scores
+
+        attention_probs = attention_probs.to(dtype)
+
+        return attention_probs
+
+    def prepare_attention_mask(
+        self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
+    ) -> torch.Tensor:
+        """
+        Prepare the attention mask for the attention computation.
+
+        Args:
+            attention_mask (`torch.Tensor`): The attention mask to prepare.
+            target_length (`int`): The target length of the attention mask.
+            batch_size (`int`): The batch size for repeating the attention mask.
+            out_dim (`int`, *optional*, defaults to `3`): Output dimension.
+
+        Returns:
+            `torch.Tensor`: The prepared attention mask.
+        """
+        head_size = self.heads
+        if attention_mask is None:
+            return attention_mask
+
+        current_length: int = attention_mask.shape[-1]
+        if current_length != target_length:
+            if attention_mask.device.type == "mps":
+                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+                # Instead, we can manually construct the padding tensor.
+                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
+                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+                attention_mask = torch.cat([attention_mask, padding], dim=2)
+            else:
+                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+                #       we want to instead pad by (0, remaining_length), where remaining_length is:
+                #       remaining_length: int = target_length - current_length
+                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
+                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+
+        if out_dim == 3:
+            if attention_mask.shape[0] < batch_size * head_size:
+                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+        elif out_dim == 4:
+            attention_mask = attention_mask.unsqueeze(1)
+            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
+        return attention_mask
+
+    def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+        """
+        Normalize the encoder hidden states.
+
+        Args:
+            encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
+
+        Returns:
+            `torch.Tensor`: The normalized encoder hidden states.
+        """
+        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
+        if isinstance(self.norm_cross, nn.LayerNorm):
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+        elif isinstance(self.norm_cross, nn.GroupNorm):
+            # Group norm norms along the channels dimension and expects
+            # input to be in the shape of (N, C, *). In this case, we want
+            # to norm along the hidden dimension, so we need to move
+            # (batch_size, sequence_length, hidden_size) ->
+            # (batch_size, hidden_size, sequence_length)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+        else:
+            assert False
+
+        return encoder_hidden_states
+
+
 def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
     # "feed_forward_chunk_size" can be used to save memory
     if hidden_states.shape[chunk_dim] % chunk_size != 0:
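
Note: the hunk above introduces two mixins: `AttentionMixin`, which gives a model recursive `attn_processors` management plus `fuse_qkv_projections`/`unfuse_qkv_projections`, and `AttentionModuleMixin`, which carries per-module processor handling, backend selection via `set_attention_backend`, projection fusion, and the mask/score helpers formerly found on `Attention`. A minimal sketch of driving the mixins directly; whether a particular model's attention modules inherit `AttentionModuleMixin` is an assumption here, and the backend string is illustrative:

# Hedged sketch based only on the mixin code above: route every attention module of a
# model to a chosen backend and toggle fused QKV projections.
import torch
from diffusers.models.attention import AttentionModuleMixin


def set_backend_everywhere(model: torch.nn.Module, backend: str = "xformers") -> None:
    # set_attention_backend validates `backend` against AttentionBackendName values
    # and stores the resulting enum member on the module's processor.
    for module in model.modules():
        if isinstance(module, AttentionModuleMixin):
            module.set_attention_backend(backend)


def toggle_fused_qkv(model: torch.nn.Module, enable: bool = True) -> None:
    # Mirrors the loop used by AttentionMixin.fuse_qkv_projections / unfuse_qkv_projections,
    # but works on any nn.Module whose attention layers use the new mixin.
    for module in model.modules():
        if isinstance(module, AttentionModuleMixin):
            if enable:
                module.fuse_projections()
            else:
                module.unfuse_projections()
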