diffusers 0.34.0__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- diffusers/__init__.py +98 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +2 -0
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_table.py +3 -3
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +7 -6
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +292 -286
- diffusers/hooks/hooks.py +56 -1
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +2 -7
- diffusers/hooks/pyramid_attention_broadcast.py +14 -11
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +255 -4
- diffusers/loaders/lora_base.py +63 -30
- diffusers/loaders/lora_conversion_utils.py +434 -53
- diffusers/loaders/lora_pipeline.py +834 -37
- diffusers/loaders/peft.py +28 -5
- diffusers/loaders/single_file_model.py +44 -11
- diffusers/loaders/single_file_utils.py +170 -2
- diffusers/loaders/transformer_flux.py +9 -10
- diffusers/loaders/transformer_sd3.py +6 -1
- diffusers/loaders/unet.py +22 -5
- diffusers/loaders/unet_loader_utils.py +5 -2
- diffusers/models/__init__.py +8 -0
- diffusers/models/attention.py +484 -3
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_processor.py +105 -663
- diffusers/models/auto_model.py +2 -2
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_dc.py +14 -1
- diffusers/models/autoencoders/autoencoder_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
- diffusers/models/cache_utils.py +31 -9
- diffusers/models/controlnets/controlnet_flux.py +5 -5
- diffusers/models/controlnets/controlnet_union.py +4 -4
- diffusers/models/embeddings.py +26 -34
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +159 -94
- diffusers/models/transformers/__init__.py +2 -0
- diffusers/models/transformers/transformer_chroma.py +16 -117
- diffusers/models/transformers/transformer_cogview4.py +36 -2
- diffusers/models/transformers/transformer_cosmos.py +11 -4
- diffusers/models/transformers/transformer_flux.py +372 -132
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
- diffusers/models/transformers/transformer_ltx.py +104 -23
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_wan.py +298 -85
- diffusers/models/transformers/transformer_wan_vace.py +15 -21
- diffusers/models/unets/unet_2d_condition.py +2 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +31 -0
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
- diffusers/pipelines/auto_pipeline.py +17 -13
- diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
- diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +3 -1
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/pipeline_flux.py +34 -26
- diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
- diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
- diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_loading_utils.py +24 -2
- diffusers/pipelines/pipeline_utils.py +22 -15
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
- diffusers/pipelines/wan/pipeline_wan.py +78 -20
- diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
- diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
- diffusers/quantizers/__init__.py +1 -177
- diffusers/quantizers/base.py +11 -0
- diffusers/quantizers/gguf/utils.py +92 -3
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
- diffusers/schedulers/scheduling_deis_multistep.py +8 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
- diffusers/schedulers/scheduling_scm.py +0 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
- diffusers/schedulers/scheduling_utils.py +2 -2
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/training_utils.py +78 -0
- diffusers/utils/__init__.py +10 -0
- diffusers/utils/constants.py +4 -0
- diffusers/utils/dummy_pt_objects.py +312 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
- diffusers/utils/dynamic_modules_utils.py +84 -25
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +70 -0
- diffusers/utils/peft_utils.py +11 -8
- diffusers/utils/testing_utils.py +136 -10
- diffusers/utils/torch_utils.py +18 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/METADATA +6 -6
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/RECORD +191 -127
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/WHEEL +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
diffusers/loaders/unet_loader_utils.py CHANGED

@@ -14,6 +14,8 @@
 import copy
 from typing import TYPE_CHECKING, Dict, List, Union
 
+from torch import nn
+
 from ..utils import logging
 
 
@@ -52,7 +54,7 @@ def _maybe_expand_lora_scales(
             weight_for_adapter,
             blocks_with_transformer,
             transformer_per_block,
-            unet
+            model=unet,
             default_scale=default_scale,
         )
         for weight_for_adapter in weight_scales
@@ -65,7 +67,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
     scales: Union[float, Dict],
     blocks_with_transformer: Dict[str, int],
     transformer_per_block: Dict[str, int],
-
+    model: nn.Module,
     default_scale: float = 1.0,
 ):
     """
@@ -154,6 +156,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
 
         del scales[updown]
 
+    state_dict = model.state_dict()
     for layer in scales.keys():
         if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
             raise ValueError(
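For orientation (not part of the diff): after this change, callers of the per-adapter scale expansion helper pass the UNet module itself and the helper derives the state dict internally. A minimal sketch under that assumption; the wrapper function and placeholder arguments below are hypothetical.

```python
# Sketch only, assuming diffusers 0.35.0 is installed. The helper and its new
# `model=` keyword come from the hunk above; `expand_scales` is a hypothetical wrapper.
from diffusers.loaders.unet_loader_utils import _maybe_expand_lora_scales_for_one_adapter


def expand_scales(scales, blocks_with_transformer, transformer_per_block, unet):
    return _maybe_expand_lora_scales_for_one_adapter(
        scales,
        blocks_with_transformer,
        transformer_per_block,
        model=unet,  # new in 0.35: pass the module; the helper calls model.state_dict() itself
        default_scale=1.0,
    )
```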
diffusers/models/__init__.py CHANGED

@@ -26,6 +26,7 @@ _import_structure = {}
 
 if is_torch_available():
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
+    _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
     _import_structure["auto_model"] = ["AutoModel"]
     _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
     _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
@@ -37,6 +38,7 @@ if is_torch_available():
     _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
     _import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
     _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
+    _import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
     _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
     _import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
     _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
@@ -87,7 +89,9 @@ if is_torch_available():
     _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
     _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
     _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
+    _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
     _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
+    _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
     _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
     _import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
     _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
@@ -111,6 +115,7 @@ if is_flax_available():
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
         from .adapter import MultiAdapter, T2IAdapter
+        from .attention_dispatch import AttentionBackendName, attention_backend
        from .auto_model import AutoModel
        from .autoencoders import (
            AsymmetricAutoencoderKL,
@@ -123,6 +128,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            AutoencoderKLLTXVideo,
            AutoencoderKLMagvit,
            AutoencoderKLMochi,
+            AutoencoderKLQwenImage,
            AutoencoderKLTemporalDecoder,
            AutoencoderKLWan,
            AutoencoderOobleck,
@@ -174,8 +180,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            OmniGenTransformer2DModel,
            PixArtTransformer2DModel,
            PriorTransformer,
+            QwenImageTransformer2DModel,
            SanaTransformer2DModel,
            SD3Transformer2DModel,
+            SkyReelsV2Transformer3DModel,
            StableAudioDiTModel,
            T5FilmDecoder,
            Transformer2DModel,
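For orientation (not part of the diff): the sketch below imports the names that this hunk newly exports from `diffusers.models` in 0.35.0. It assumes the 0.35.0 wheel is installed and only touches symbols that appear in the hunks above; treating `AttentionBackendName` as an enum is based on `attention.py` iterating its `__members__` in the next file.

```python
# Sketch only, assuming diffusers 0.35.0. All imported names are newly exported
# by the diffusers/models/__init__.py hunk above.
from diffusers.models import (
    AttentionBackendName,
    AutoencoderKLQwenImage,
    QwenImageTransformer2DModel,
    SkyReelsV2Transformer3DModel,
    attention_backend,
)

# The .value strings of AttentionBackendName are what
# AttentionModuleMixin.set_attention_backend() validates against (see attention.py below).
print(sorted(backend.value for backend in AttentionBackendName))
```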
diffusers/models/attention.py CHANGED

@@ -11,23 +11,504 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
-from torch import nn
 
 from ..utils import deprecate, logging
+from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available
 from ..utils.torch_utils import maybe_allow_in_graph
 from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
-from .attention_processor import Attention, JointAttnProcessor2_0
+from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
 from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
 
 
+if is_xformers_available():
+    import xformers as xops
+else:
+    xops = None
+
+
 logger = logging.get_logger(__name__)
 
 
+class AttentionMixin:
+    @property
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+        """
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        for module in self.modules():
+            if isinstance(module, AttentionModuleMixin):
+                module.fuse_projections()
+
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        for module in self.modules():
+            if isinstance(module, AttentionModuleMixin):
+                module.unfuse_projections()
+
+
+class AttentionModuleMixin:
+    _default_processor_cls = None
+    _available_processors = []
+    fused_projections = False
+
+    def set_processor(self, processor: AttentionProcessor) -> None:
+        """
+        Set the attention processor to use.
+
+        Args:
+            processor (`AttnProcessor`):
+                The attention processor to use.
+        """
+        # if current processor is in `self._modules` and if passed `processor` is not, we need to
+        # pop `processor` from `self._modules`
+        if (
+            hasattr(self, "processor")
+            and isinstance(self.processor, torch.nn.Module)
+            and not isinstance(processor, torch.nn.Module)
+        ):
+            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+            self._modules.pop("processor")
+
+        self.processor = processor
+
+    def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
+        """
+        Get the attention processor in use.
+
+        Args:
+            return_deprecated_lora (`bool`, *optional*, defaults to `False`):
+                Set to `True` to return the deprecated LoRA attention processor.
+
+        Returns:
+            "AttentionProcessor": The attention processor in use.
+        """
+        if not return_deprecated_lora:
+            return self.processor
+
+    def set_attention_backend(self, backend: str):
+        from .attention_dispatch import AttentionBackendName
+
+        available_backends = {x.value for x in AttentionBackendName.__members__.values()}
+        if backend not in available_backends:
+            raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
+
+        backend = AttentionBackendName(backend.lower())
+        self.processor._attention_backend = backend
+
+    def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
+        """
+        Set whether to use NPU flash attention from `torch_npu` or not.
+
+        Args:
+            use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.
+        """
+
+        if use_npu_flash_attention:
+            if not is_torch_npu_available():
+                raise ImportError("torch_npu is not available")
+
+            self.set_attention_backend("_native_npu")
+
+    def set_use_xla_flash_attention(
+        self,
+        use_xla_flash_attention: bool,
+        partition_spec: Optional[Tuple[Optional[str], ...]] = None,
+        is_flux=False,
+    ) -> None:
+        """
+        Set whether to use XLA flash attention from `torch_xla` or not.
+
+        Args:
+            use_xla_flash_attention (`bool`):
+                Whether to use pallas flash attention kernel from `torch_xla` or not.
+            partition_spec (`Tuple[]`, *optional*):
+                Specify the partition specification if using SPMD. Otherwise None.
+            is_flux (`bool`, *optional*, defaults to `False`):
+                Whether the model is a Flux model.
+        """
+        if use_xla_flash_attention:
+            if not is_torch_xla_available():
+                raise ImportError("torch_xla is not available")
+
+            self.set_attention_backend("_native_xla")
+
+    def set_use_memory_efficient_attention_xformers(
+        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+    ) -> None:
+        """
+        Set whether to use memory efficient attention from `xformers` or not.
+
+        Args:
+            use_memory_efficient_attention_xformers (`bool`):
+                Whether to use memory efficient attention from `xformers` or not.
+            attention_op (`Callable`, *optional*):
+                The attention operation to use. Defaults to `None` which uses the default attention operation from
+                `xformers`.
+        """
+        if use_memory_efficient_attention_xformers:
+            if not is_xformers_available():
+                raise ModuleNotFoundError(
+                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers",
+                    name="xformers",
+                )
+            elif not torch.cuda.is_available():
+                raise ValueError(
+                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
+                    " only available for GPU "
+                )
+            else:
+                try:
+                    # Make sure we can run the memory efficient attention
+                    if is_xformers_available():
+                        dtype = None
+                        if attention_op is not None:
+                            op_fw, op_bw = attention_op
+                            dtype, *_ = op_fw.SUPPORTED_DTYPES
+                        q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
+                        _ = xops.memory_efficient_attention(q, q, q)
+                except Exception as e:
+                    raise e
+
+            self.set_attention_backend("xformers")
+
+    @torch.no_grad()
+    def fuse_projections(self):
+        """
+        Fuse the query, key, and value projections into a single projection for efficiency.
+        """
+        # Skip if already fused
+        if getattr(self, "fused_projections", False):
+            return
+
+        device = self.to_q.weight.data.device
+        dtype = self.to_q.weight.data.dtype
+
+        if hasattr(self, "is_cross_attention") and self.is_cross_attention:
+            # Fuse cross-attention key-value projections
+            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            self.to_kv.weight.copy_(concatenated_weights)
+            if hasattr(self, "use_bias") and self.use_bias:
+                concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+                self.to_kv.bias.copy_(concatenated_bias)
+        else:
+            # Fuse self-attention projections
+            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            self.to_qkv.weight.copy_(concatenated_weights)
+            if hasattr(self, "use_bias") and self.use_bias:
+                concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+                self.to_qkv.bias.copy_(concatenated_bias)
+
+        # Handle added projections for models like SD3, Flux, etc.
+        if (
+            getattr(self, "add_q_proj", None) is not None
+            and getattr(self, "add_k_proj", None) is not None
+            and getattr(self, "add_v_proj", None) is not None
+        ):
+            concatenated_weights = torch.cat(
+                [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
+            )
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_added_qkv = nn.Linear(
+                in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
+            )
+            self.to_added_qkv.weight.copy_(concatenated_weights)
+            if self.added_proj_bias:
+                concatenated_bias = torch.cat(
+                    [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
+                )
+                self.to_added_qkv.bias.copy_(concatenated_bias)
+
+        self.fused_projections = True
+
+    @torch.no_grad()
+    def unfuse_projections(self):
+        """
+        Unfuse the query, key, and value projections back to separate projections.
+        """
+        # Skip if not fused
+        if not getattr(self, "fused_projections", False):
+            return
+
+        # Remove fused projection layers
+        if hasattr(self, "to_qkv"):
+            delattr(self, "to_qkv")
+
+        if hasattr(self, "to_kv"):
+            delattr(self, "to_kv")
+
+        if hasattr(self, "to_added_qkv"):
+            delattr(self, "to_added_qkv")
+
+        self.fused_projections = False
+
+    def set_attention_slice(self, slice_size: int) -> None:
+        """
+        Set the slice size for attention computation.
+
+        Args:
+            slice_size (`int`):
+                The slice size for attention computation.
+        """
+        if hasattr(self, "sliceable_head_dim") and slice_size is not None and slice_size > self.sliceable_head_dim:
+            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+        processor = None
+
+        # Try to get a compatible processor for sliced attention
+        if slice_size is not None:
+            processor = self._get_compatible_processor("sliced")
+
+        # If no processor was found or slice_size is None, use default processor
+        if processor is None:
+            processor = self.default_processor_cls()
+
+        self.set_processor(processor)
+
+    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
+        """
+        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`.
+
+        Args:
+            tensor (`torch.Tensor`): The tensor to reshape.
+
+        Returns:
+            `torch.Tensor`: The reshaped tensor.
+        """
+        head_size = self.heads
+        batch_size, seq_len, dim = tensor.shape
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+        return tensor
+
+    def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
+        """
+        Reshape the tensor for multi-head attention processing.
+
+        Args:
+            tensor (`torch.Tensor`): The tensor to reshape.
+            out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor.
+
+        Returns:
+            `torch.Tensor`: The reshaped tensor.
+        """
+        head_size = self.heads
+        if tensor.ndim == 3:
+            batch_size, seq_len, dim = tensor.shape
+            extra_dim = 1
+        else:
+            batch_size, extra_dim, seq_len, dim = tensor.shape
+        tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3)
+
+        if out_dim == 3:
+            tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
+
+        return tensor
+
+    def get_attention_scores(
+        self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        """
+        Compute the attention scores.
+
+        Args:
+            query (`torch.Tensor`): The query tensor.
+            key (`torch.Tensor`): The key tensor.
+            attention_mask (`torch.Tensor`, *optional*): The attention mask to use.
+
+        Returns:
+            `torch.Tensor`: The attention probabilities/scores.
+        """
+        dtype = query.dtype
+        if self.upcast_attention:
+            query = query.float()
+            key = key.float()
+
+        if attention_mask is None:
+            baddbmm_input = torch.empty(
+                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+            )
+            beta = 0
+        else:
+            baddbmm_input = attention_mask
+            beta = 1
+
+        attention_scores = torch.baddbmm(
+            baddbmm_input,
+            query,
+            key.transpose(-1, -2),
+            beta=beta,
+            alpha=self.scale,
+        )
+        del baddbmm_input
+
+        if self.upcast_softmax:
+            attention_scores = attention_scores.float()
+
+        attention_probs = attention_scores.softmax(dim=-1)
+        del attention_scores
+
+        attention_probs = attention_probs.to(dtype)
+
+        return attention_probs
+
+    def prepare_attention_mask(
+        self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
+    ) -> torch.Tensor:
+        """
+        Prepare the attention mask for the attention computation.
+
+        Args:
+            attention_mask (`torch.Tensor`): The attention mask to prepare.
+            target_length (`int`): The target length of the attention mask.
+            batch_size (`int`): The batch size for repeating the attention mask.
+            out_dim (`int`, *optional*, defaults to `3`): Output dimension.
+
+        Returns:
+            `torch.Tensor`: The prepared attention mask.
+        """
+        head_size = self.heads
+        if attention_mask is None:
+            return attention_mask
+
+        current_length: int = attention_mask.shape[-1]
+        if current_length != target_length:
+            if attention_mask.device.type == "mps":
+                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+                # Instead, we can manually construct the padding tensor.
+                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
+                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+                attention_mask = torch.cat([attention_mask, padding], dim=2)
+            else:
+                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+                # we want to instead pad by (0, remaining_length), where remaining_length is:
+                # remaining_length: int = target_length - current_length
+                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
+                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+
+        if out_dim == 3:
+            if attention_mask.shape[0] < batch_size * head_size:
+                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+        elif out_dim == 4:
+            attention_mask = attention_mask.unsqueeze(1)
+            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
+        return attention_mask
+
+    def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+        """
+        Normalize the encoder hidden states.
+
+        Args:
+            encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
+
+        Returns:
+            `torch.Tensor`: The normalized encoder hidden states.
+        """
+        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
+        if isinstance(self.norm_cross, nn.LayerNorm):
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+        elif isinstance(self.norm_cross, nn.GroupNorm):
+            # Group norm norms along the channels dimension and expects
+            # input to be in the shape of (N, C, *). In this case, we want
+            # to norm along the hidden dimension, so we need to move
+            # (batch_size, sequence_length, hidden_size) ->
+            # (batch_size, hidden_size, sequence_length)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+        else:
+            assert False
+
+        return encoder_hidden_states
+
+
 def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
     # "feed_forward_chunk_size" can be used to save memory
     if hidden_states.shape[chunk_dim] % chunk_size != 0:
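For orientation (not part of the diff): a small sketch of how the two new mixins above can be inspected on an arbitrary model. It assumes diffusers 0.35.0 is installed and uses only methods defined in this hunk (`attn_processors`, `get_processor`); whether a given transformer class actually inherits these mixins is model-dependent and not shown here.

```python
# Sketch only, assuming diffusers 0.35.0. AttentionMixin and AttentionModuleMixin are
# the classes added to diffusers/models/attention.py in the hunk above.
import torch
from diffusers.models.attention import AttentionMixin, AttentionModuleMixin


def describe_attention(model: torch.nn.Module) -> None:
    # Model-level helpers (attn_processors, set_attn_processor, fuse_qkv_projections)
    # are only available if the model class mixes in AttentionMixin.
    if isinstance(model, AttentionMixin):
        print(f"{len(model.attn_processors)} attention processors registered")

    # Per-module helpers (get_processor, set_attention_backend, fuse_projections)
    # live on attention modules that mix in AttentionModuleMixin.
    for name, module in model.named_modules():
        if isinstance(module, AttentionModuleMixin):
            print(f"{name}: {type(module.get_processor()).__name__}")
```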
|