diffusers-0.34.0-py3-none-any.whl → diffusers-0.35.0-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.
- diffusers/__init__.py +98 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +2 -0
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_table.py +3 -3
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +7 -6
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +292 -286
- diffusers/hooks/hooks.py +56 -1
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +2 -7
- diffusers/hooks/pyramid_attention_broadcast.py +14 -11
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +255 -4
- diffusers/loaders/lora_base.py +63 -30
- diffusers/loaders/lora_conversion_utils.py +434 -53
- diffusers/loaders/lora_pipeline.py +834 -37
- diffusers/loaders/peft.py +28 -5
- diffusers/loaders/single_file_model.py +44 -11
- diffusers/loaders/single_file_utils.py +170 -2
- diffusers/loaders/transformer_flux.py +9 -10
- diffusers/loaders/transformer_sd3.py +6 -1
- diffusers/loaders/unet.py +22 -5
- diffusers/loaders/unet_loader_utils.py +5 -2
- diffusers/models/__init__.py +8 -0
- diffusers/models/attention.py +484 -3
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_processor.py +105 -663
- diffusers/models/auto_model.py +2 -2
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_dc.py +14 -1
- diffusers/models/autoencoders/autoencoder_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
- diffusers/models/cache_utils.py +31 -9
- diffusers/models/controlnets/controlnet_flux.py +5 -5
- diffusers/models/controlnets/controlnet_union.py +4 -4
- diffusers/models/embeddings.py +26 -34
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +159 -94
- diffusers/models/transformers/__init__.py +2 -0
- diffusers/models/transformers/transformer_chroma.py +16 -117
- diffusers/models/transformers/transformer_cogview4.py +36 -2
- diffusers/models/transformers/transformer_cosmos.py +11 -4
- diffusers/models/transformers/transformer_flux.py +372 -132
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
- diffusers/models/transformers/transformer_ltx.py +104 -23
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_wan.py +298 -85
- diffusers/models/transformers/transformer_wan_vace.py +15 -21
- diffusers/models/unets/unet_2d_condition.py +2 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +31 -0
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
- diffusers/pipelines/auto_pipeline.py +17 -13
- diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
- diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +3 -1
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/pipeline_flux.py +34 -26
- diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
- diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
- diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_loading_utils.py +24 -2
- diffusers/pipelines/pipeline_utils.py +22 -15
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
- diffusers/pipelines/wan/pipeline_wan.py +78 -20
- diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
- diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
- diffusers/quantizers/__init__.py +1 -177
- diffusers/quantizers/base.py +11 -0
- diffusers/quantizers/gguf/utils.py +92 -3
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
- diffusers/schedulers/scheduling_deis_multistep.py +8 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
- diffusers/schedulers/scheduling_scm.py +0 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
- diffusers/schedulers/scheduling_utils.py +2 -2
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/training_utils.py +78 -0
- diffusers/utils/__init__.py +10 -0
- diffusers/utils/constants.py +4 -0
- diffusers/utils/dummy_pt_objects.py +312 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
- diffusers/utils/dynamic_modules_utils.py +84 -25
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +70 -0
- diffusers/utils/peft_utils.py +11 -8
- diffusers/utils/testing_utils.py +136 -10
- diffusers/utils/torch_utils.py +18 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/METADATA +6 -6
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/RECORD +191 -127
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/WHEEL +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
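Among the new files above are several new pipelines (QwenImage, SkyReels-V2, Flux Kontext) and the `modular_pipelines` subpackage. As a rough sketch of how one of them would be exercised — the checkpoint id and call keywords below are assumptions for illustration, not something this diff specifies:

```python
import torch
from diffusers import FluxKontextPipeline  # backed by the new pipelines/flux/pipeline_flux_kontext.py
from diffusers.utils import load_image

# Assumed checkpoint id; substitute the Kontext checkpoint you actually use.
pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

source = load_image("path/or/url/to/input.png")
edited = pipe(image=source, prompt="turn the sketch into a watercolor painting", guidance_scale=2.5).images[0]
edited.save("kontext_edit.png")
```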
diffusers/models/transformers/transformer_wan.py

@@ -21,9 +21,10 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from
-from ..
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
@@ -34,18 +35,51 @@ from ..normalization import FP32LayerNorm
 logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
 
-
+def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
+    # encoder_hidden_states is only passed for cross-attention
+    if encoder_hidden_states is None:
+        encoder_hidden_states = hidden_states
+
+    if attn.fused_projections:
+        if attn.cross_attention_dim_head is None:
+            # In self-attention layers, we can fuse the entire QKV projection into a single linear
+            query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+        else:
+            # In cross-attention layers, we can only fuse the KV projections into a single linear
+            query = attn.to_q(hidden_states)
+            key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+    else:
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+    return query, key, value
+
+
+def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
+    if attn.fused_projections:
+        key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
+    else:
+        key_img = attn.add_k_proj(encoder_hidden_states_img)
+        value_img = attn.add_v_proj(encoder_hidden_states_img)
+    return key_img, value_img
+
+
+class WanAttnProcessor:
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
+            raise ImportError(
+                "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
+            )
 
     def __call__(
         self,
-        attn:
+        attn: "WanAttention",
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        rotary_emb: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         encoder_hidden_states_img = None
         if attn.add_k_proj is not None:
@@ -53,53 +87,65 @@ class WanAttnProcessor2_0:
             image_context_length = encoder_hidden_states.shape[1] - 512
             encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
             encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
 
-        query = attn
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
+        query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
 
-
-
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)
 
-        query = query.unflatten(2, (attn.heads, -1))
-        key = key.unflatten(2, (attn.heads, -1))
-        value = value.unflatten(2, (attn.heads, -1))
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
 
         if rotary_emb is not None:
 
-            def apply_rotary_emb(
-
-
-
-
-
-
-
+            def apply_rotary_emb(
+                hidden_states: torch.Tensor,
+                freqs_cos: torch.Tensor,
+                freqs_sin: torch.Tensor,
+            ):
+                x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+                cos = freqs_cos[..., 0::2]
+                sin = freqs_sin[..., 1::2]
+                out = torch.empty_like(hidden_states)
+                out[..., 0::2] = x1 * cos - x2 * sin
+                out[..., 1::2] = x1 * sin + x2 * cos
+                return out.type_as(hidden_states)
+
+            query = apply_rotary_emb(query, *rotary_emb)
+            key = apply_rotary_emb(key, *rotary_emb)
 
         # I2V task
         hidden_states_img = None
         if encoder_hidden_states_img is not None:
-            key_img = attn
+            key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
             key_img = attn.norm_added_k(key_img)
-            value_img = attn.add_v_proj(encoder_hidden_states_img)
-
-            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
 
-
-
+            key_img = key_img.unflatten(2, (attn.heads, -1))
+            value_img = value_img.unflatten(2, (attn.heads, -1))
+
+            hidden_states_img = dispatch_attention_fn(
+                query,
+                key_img,
+                value_img,
+                attn_mask=None,
+                dropout_p=0.0,
+                is_causal=False,
+                backend=self._attention_backend,
            )
-            hidden_states_img = hidden_states_img.
+            hidden_states_img = hidden_states_img.flatten(2, 3)
             hidden_states_img = hidden_states_img.type_as(query)
 
-        hidden_states =
-            query,
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
-        hidden_states = hidden_states.
+        hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.type_as(query)
 
         if hidden_states_img is not None:
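The inlined `apply_rotary_emb` helper added above rotates interleaved (even, odd) channel pairs using per-position cos/sin tables. The following self-contained sketch reproduces that math outside the processor, with made-up frequencies purely for illustration (it is not part of the diff):

```python
import torch

def apply_rotary_emb(x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
    # Same math as the helper in WanAttnProcessor: (x1, x2) pairs sit at even/odd channel positions.
    x1, x2 = x.unflatten(-1, (-1, 2)).unbind(-1)
    cos = freqs_cos[..., 0::2]
    sin = freqs_sin[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out.type_as(x)

seq_len, head_dim = 8, 4
# Toy frequency table; the real one comes from get_1d_rotary_pos_embed with repeat_interleave_real=True,
# which duplicates each cos/sin value across the two channels of its pair.
angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), torch.tensor([1.0, 0.01]))
freqs_cos = angles.cos().repeat_interleave(2, dim=-1)  # (seq_len, head_dim)
freqs_sin = angles.sin().repeat_interleave(2, dim=-1)

q = torch.randn(seq_len, head_dim)
q_rot = apply_rotary_emb(q, freqs_cos, freqs_sin)
# A pure rotation preserves the norm of every (even, odd) channel pair, hence of the whole vector.
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # True
```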
@@ -110,6 +156,122 @@ class WanAttnProcessor2_0:
         return hidden_states
 
 
+class WanAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = (
+            "The WanAttnProcessor2_0 class is deprecated and will be removed in a future version. "
+            "Please use WanAttnProcessor instead. "
+        )
+        deprecate("WanAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+        return WanAttnProcessor(*args, **kwargs)
+
+
+class WanAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = WanAttnProcessor
+    _available_processors = [WanAttnProcessor]
+
+    def __init__(
+        self,
+        dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        eps: float = 1e-5,
+        dropout: float = 0.0,
+        added_kv_proj_dim: Optional[int] = None,
+        cross_attention_dim_head: Optional[int] = None,
+        processor=None,
+        is_cross_attention=None,
+    ):
+        super().__init__()
+
+        self.inner_dim = dim_head * heads
+        self.heads = heads
+        self.added_kv_proj_dim = added_kv_proj_dim
+        self.cross_attention_dim_head = cross_attention_dim_head
+        self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+
+        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+        self.to_out = torch.nn.ModuleList(
+            [
+                torch.nn.Linear(self.inner_dim, dim, bias=True),
+                torch.nn.Dropout(dropout),
+            ]
+        )
+        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+        self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+
+        self.add_k_proj = self.add_v_proj = None
+        if added_kv_proj_dim is not None:
+            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+            self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
+
+        self.is_cross_attention = cross_attention_dim_head is not None
+
+        self.set_processor(processor)
+
+    def fuse_projections(self):
+        if getattr(self, "fused_projections", False):
+            return
+
+        if self.cross_attention_dim_head is None:
+            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+            concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_qkv = nn.Linear(in_features, out_features, bias=True)
+            self.to_qkv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+        else:
+            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+            concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_kv = nn.Linear(in_features, out_features, bias=True)
+            self.to_kv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+
+        if self.added_kv_proj_dim is not None:
+            concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
+            concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
+            self.to_added_kv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+
+        self.fused_projections = True
+
+    @torch.no_grad()
+    def unfuse_projections(self):
+        if not getattr(self, "fused_projections", False):
+            return
+
+        if hasattr(self, "to_qkv"):
+            delattr(self, "to_qkv")
+        if hasattr(self, "to_kv"):
+            delattr(self, "to_kv")
+        if hasattr(self, "to_added_kv"):
+            delattr(self, "to_added_kv")
+
+        self.fused_projections = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
+
+
 class WanImageEmbedding(torch.nn.Module):
     def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
         super().__init__()
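The hunk above keeps the old name importable: `WanAttnProcessor2_0.__new__` routes through `deprecate(...)` and hands back the new `WanAttnProcessor`. A minimal sketch of what downstream code sees with 0.35.0 installed:

```python
from diffusers.models.transformers.transformer_wan import WanAttnProcessor, WanAttnProcessor2_0

proc = WanAttnProcessor2_0()            # emits a deprecation warning via deprecate(...)
print(type(proc) is WanAttnProcessor)   # True: the shim returns an instance of the new class
```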
@@ -161,8 +323,11 @@ class WanTimeTextImageEmbedding(nn.Module):
         timestep: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         encoder_hidden_states_image: Optional[torch.Tensor] = None,
+        timestep_seq_len: Optional[int] = None,
     ):
         timestep = self.timesteps_proj(timestep)
+        if timestep_seq_len is not None:
+            timestep = timestep.unflatten(0, (-1, timestep_seq_len))
 
         time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
         if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
@@ -179,7 +344,11 @@ class WanTimeTextImageEmbedding(nn.Module):
 
 class WanRotaryPosEmbed(nn.Module):
     def __init__(
-        self,
+        self,
+        attention_head_dim: int,
+        patch_size: Tuple[int, int, int],
+        max_seq_len: int,
+        theta: float = 10000.0,
     ):
         super().__init__()
 
@@ -189,38 +358,55 @@ class WanRotaryPosEmbed(nn.Module):
 
         h_dim = w_dim = 2 * (attention_head_dim // 6)
         t_dim = attention_head_dim - h_dim - w_dim
-
-        freqs = []
         freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+
+        freqs_cos = []
+        freqs_sin = []
+
         for dim in [t_dim, h_dim, w_dim]:
-
-            dim,
+            freq_cos, freq_sin = get_1d_rotary_pos_embed(
+                dim,
+                max_seq_len,
+                theta,
+                use_real=True,
+                repeat_interleave_real=True,
+                freqs_dtype=freqs_dtype,
             )
-
-
+            freqs_cos.append(freq_cos)
+            freqs_sin.append(freq_sin)
+
+        self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+        self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = hidden_states.shape
         p_t, p_h, p_w = self.patch_size
         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
 
-
-
-
-
-
-
-
-
-        )
+        split_sizes = [
+            self.attention_head_dim - 2 * (self.attention_head_dim // 3),
+            self.attention_head_dim // 3,
+            self.attention_head_dim // 3,
+        ]
+
+        freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+        freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
 
-
-
-
-        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
-        return freqs
+        freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
 
+        freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
 
+        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+
+        return freqs_cos, freqs_sin
+
+
+@maybe_allow_in_graph
 class WanTransformerBlock(nn.Module):
     def __init__(
         self,
@@ -236,33 +422,24 @@ class WanTransformerBlock(nn.Module):
 
         # 1. Self-attention
         self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
-        self.attn1 =
-
+        self.attn1 = WanAttention(
+            dim=dim,
             heads=num_heads,
-            kv_heads=num_heads,
             dim_head=dim // num_heads,
-            qk_norm=qk_norm,
             eps=eps,
-
-
-            out_bias=True,
-            processor=WanAttnProcessor2_0(),
+            cross_attention_dim_head=None,
+            processor=WanAttnProcessor(),
        )
 
         # 2. Cross-attention
-        self.attn2 =
-
+        self.attn2 = WanAttention(
+            dim=dim,
             heads=num_heads,
-            kv_heads=num_heads,
             dim_head=dim // num_heads,
-            qk_norm=qk_norm,
             eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
             added_kv_proj_dim=added_kv_proj_dim,
-
-            processor=
+            cross_attention_dim_head=dim // num_heads,
+            processor=WanAttnProcessor(),
         )
         self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
 
@@ -279,18 +456,32 @@ class WanTransformerBlock(nn.Module):
         temb: torch.Tensor,
         rotary_emb: torch.Tensor,
     ) -> torch.Tensor:
-
-
-
+        if temb.ndim == 4:
+            # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+                self.scale_shift_table.unsqueeze(0) + temb.float()
+            ).chunk(6, dim=2)
+            # batch_size, seq_len, 1, inner_dim
+            shift_msa = shift_msa.squeeze(2)
+            scale_msa = scale_msa.squeeze(2)
+            gate_msa = gate_msa.squeeze(2)
+            c_shift_msa = c_shift_msa.squeeze(2)
+            c_scale_msa = c_scale_msa.squeeze(2)
+            c_gate_msa = c_gate_msa.squeeze(2)
+        else:
+            # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+                self.scale_shift_table + temb.float()
+            ).chunk(6, dim=1)
 
         # 1. Self-attention
         norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
-        attn_output = self.attn1(
+        attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
         hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
 
         # 2. Cross-attention
         norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
-        attn_output = self.attn2(
+        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
         hidden_states = hidden_states + attn_output
 
         # 3. Feed-forward
@@ -303,7 +494,9 @@ class WanTransformerBlock(nn.Module):
         return hidden_states
 
 
-class WanTransformer3DModel(
+class WanTransformer3DModel(
+    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
     r"""
     A Transformer model for video-like data used in the Wan model.
 
@@ -345,6 +538,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
     _no_split_modules = ["WanTransformerBlock"]
     _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
     _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+    _repeated_blocks = ["WanTransformerBlock"]
 
     @register_to_config
     def __init__(
@@ -438,10 +632,22 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
         hidden_states = self.patch_embedding(hidden_states)
         hidden_states = hidden_states.flatten(2).transpose(1, 2)
 
+        # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
+        if timestep.ndim == 2:
+            ts_seq_len = timestep.shape[1]
+            timestep = timestep.flatten() # batch_size * seq_len
+        else:
+            ts_seq_len = None
+
         temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
-            timestep, encoder_hidden_states, encoder_hidden_states_image
+            timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
         )
-
+        if ts_seq_len is not None:
+            # batch_size, seq_len, 6, inner_dim
+            timestep_proj = timestep_proj.unflatten(2, (6, -1))
+        else:
+            # batch_size, 6, inner_dim
+            timestep_proj = timestep_proj.unflatten(1, (6, -1))
 
         if encoder_hidden_states_image is not None:
             encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
@@ -457,7 +663,14 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
             hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
 
         # 5. Output norm, projection & unpatchify
-
+        if temb.ndim == 3:
+            # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
+            shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
+            shift = shift.squeeze(2)
+            scale = scale.squeeze(2)
+        else:
+            # batch_size, inner_dim
+            shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
 
         # Move the shift and scale tensors to the same device as hidden_states.
         # When using multi-GPU inference via accelerate these will be on the
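To make the refactor concrete, here is a small usage sketch (sizes are illustrative, not taken from the diff) that builds the new `WanAttention` the way `WanTransformerBlock` now builds its self-attention layer, runs it, and then exercises the new fused-projection path:

```python
import torch
from diffusers.models.transformers.transformer_wan import WanAttention, WanAttnProcessor

# Mirrors the attn1 construction in WanTransformerBlock; dim/heads/dim_head here are made up.
attn = WanAttention(dim=128, heads=8, dim_head=16, cross_attention_dim_head=None, processor=WanAttnProcessor())

hidden_states = torch.randn(2, 32, 128)  # (batch, seq_len, dim)
out = attn(hidden_states)                # self-attention; rotary_emb is optional
print(out.shape)                         # torch.Size([2, 32, 128])

attn.fuse_projections()                  # folds to_q/to_k/to_v into a single to_qkv linear
out_fused = attn(hidden_states)          # same forward path, now through the fused projection
```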
diffusers/models/transformers/transformer_wan_vace.py

@@ -22,12 +22,17 @@ from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import FeedForward
-from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import FP32LayerNorm
-from .transformer_wan import
+from .transformer_wan import (
+    WanAttention,
+    WanAttnProcessor,
+    WanRotaryPosEmbed,
+    WanTimeTextImageEmbedding,
+    WanTransformerBlock,
+)
 
 
 logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -55,33 +60,22 @@ class WanVACETransformerBlock(nn.Module):
 
         # 2. Self-attention
         self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
-        self.attn1 =
-
+        self.attn1 = WanAttention(
+            dim=dim,
             heads=num_heads,
-            kv_heads=num_heads,
             dim_head=dim // num_heads,
-            qk_norm=qk_norm,
             eps=eps,
-
-            cross_attention_dim=None,
-            out_bias=True,
-            processor=WanAttnProcessor2_0(),
+            processor=WanAttnProcessor(),
         )
 
         # 3. Cross-attention
-        self.attn2 =
-
+        self.attn2 = WanAttention(
+            dim=dim,
             heads=num_heads,
-            kv_heads=num_heads,
             dim_head=dim // num_heads,
-            qk_norm=qk_norm,
             eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
             added_kv_proj_dim=added_kv_proj_dim,
-
-            processor=WanAttnProcessor2_0(),
+            processor=WanAttnProcessor(),
         )
         self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
 
@@ -116,12 +110,12 @@ class WanVACETransformerBlock(nn.Module):
         norm_hidden_states = (self.norm1(control_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(
             control_hidden_states
         )
-        attn_output = self.attn1(
+        attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
         control_hidden_states = (control_hidden_states.float() + attn_output * gate_msa).type_as(control_hidden_states)
 
         # 2. Cross-attention
         norm_hidden_states = self.norm2(control_hidden_states.float()).type_as(control_hidden_states)
-        attn_output = self.attn2(
+        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
         control_hidden_states = control_hidden_states + attn_output
 
         # 3. Feed-forward
diffusers/models/unets/unet_2d_condition.py

@@ -165,8 +165,9 @@ class UNet2DConditionModel(
     """
 
     _supports_gradient_checkpointing = True
-    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
+    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"]
     _skip_layerwise_casting_patterns = ["norm"]
+    _repeated_blocks = ["BasicTransformerBlock"]
 
     @register_to_config
     def __init__(