diffusers 0.34.0__py3-none-any.whl → 0.35.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. diffusers/__init__.py +98 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/custom_blocks.py +134 -0
  4. diffusers/commands/diffusers_cli.py +2 -0
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +11 -2
  7. diffusers/dependency_versions_table.py +3 -3
  8. diffusers/guiders/__init__.py +41 -0
  9. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  10. diffusers/guiders/auto_guidance.py +190 -0
  11. diffusers/guiders/classifier_free_guidance.py +141 -0
  12. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  13. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  14. diffusers/guiders/guider_utils.py +309 -0
  15. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  16. diffusers/guiders/skip_layer_guidance.py +262 -0
  17. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  18. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  19. diffusers/hooks/__init__.py +17 -0
  20. diffusers/hooks/_common.py +56 -0
  21. diffusers/hooks/_helpers.py +293 -0
  22. diffusers/hooks/faster_cache.py +7 -6
  23. diffusers/hooks/first_block_cache.py +259 -0
  24. diffusers/hooks/group_offloading.py +292 -286
  25. diffusers/hooks/hooks.py +56 -1
  26. diffusers/hooks/layer_skip.py +263 -0
  27. diffusers/hooks/layerwise_casting.py +2 -7
  28. diffusers/hooks/pyramid_attention_broadcast.py +14 -11
  29. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  30. diffusers/hooks/utils.py +43 -0
  31. diffusers/loaders/__init__.py +6 -0
  32. diffusers/loaders/ip_adapter.py +255 -4
  33. diffusers/loaders/lora_base.py +63 -30
  34. diffusers/loaders/lora_conversion_utils.py +434 -53
  35. diffusers/loaders/lora_pipeline.py +834 -37
  36. diffusers/loaders/peft.py +28 -5
  37. diffusers/loaders/single_file_model.py +44 -11
  38. diffusers/loaders/single_file_utils.py +170 -2
  39. diffusers/loaders/transformer_flux.py +9 -10
  40. diffusers/loaders/transformer_sd3.py +6 -1
  41. diffusers/loaders/unet.py +22 -5
  42. diffusers/loaders/unet_loader_utils.py +5 -2
  43. diffusers/models/__init__.py +8 -0
  44. diffusers/models/attention.py +484 -3
  45. diffusers/models/attention_dispatch.py +1218 -0
  46. diffusers/models/attention_processor.py +105 -663
  47. diffusers/models/auto_model.py +2 -2
  48. diffusers/models/autoencoders/__init__.py +1 -0
  49. diffusers/models/autoencoders/autoencoder_dc.py +14 -1
  50. diffusers/models/autoencoders/autoencoder_kl.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
  52. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  53. diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
  54. diffusers/models/cache_utils.py +31 -9
  55. diffusers/models/controlnets/controlnet_flux.py +5 -5
  56. diffusers/models/controlnets/controlnet_union.py +4 -4
  57. diffusers/models/embeddings.py +26 -34
  58. diffusers/models/model_loading_utils.py +233 -1
  59. diffusers/models/modeling_flax_utils.py +1 -2
  60. diffusers/models/modeling_utils.py +159 -94
  61. diffusers/models/transformers/__init__.py +2 -0
  62. diffusers/models/transformers/transformer_chroma.py +16 -117
  63. diffusers/models/transformers/transformer_cogview4.py +36 -2
  64. diffusers/models/transformers/transformer_cosmos.py +11 -4
  65. diffusers/models/transformers/transformer_flux.py +372 -132
  66. diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
  67. diffusers/models/transformers/transformer_ltx.py +104 -23
  68. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  69. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  70. diffusers/models/transformers/transformer_wan.py +298 -85
  71. diffusers/models/transformers/transformer_wan_vace.py +15 -21
  72. diffusers/models/unets/unet_2d_condition.py +2 -1
  73. diffusers/modular_pipelines/__init__.py +83 -0
  74. diffusers/modular_pipelines/components_manager.py +1068 -0
  75. diffusers/modular_pipelines/flux/__init__.py +66 -0
  76. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  77. diffusers/modular_pipelines/flux/decoders.py +109 -0
  78. diffusers/modular_pipelines/flux/denoise.py +227 -0
  79. diffusers/modular_pipelines/flux/encoders.py +412 -0
  80. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  81. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  82. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  83. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  84. diffusers/modular_pipelines/node_utils.py +665 -0
  85. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  86. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  87. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  88. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  89. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  90. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  91. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  92. diffusers/modular_pipelines/wan/__init__.py +66 -0
  93. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  94. diffusers/modular_pipelines/wan/decoders.py +105 -0
  95. diffusers/modular_pipelines/wan/denoise.py +261 -0
  96. diffusers/modular_pipelines/wan/encoders.py +242 -0
  97. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  98. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  99. diffusers/pipelines/__init__.py +31 -0
  100. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
  101. diffusers/pipelines/auto_pipeline.py +17 -13
  102. diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
  103. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
  104. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
  105. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
  106. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
  107. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
  108. diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
  109. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
  110. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
  111. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
  113. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
  114. diffusers/pipelines/dit/pipeline_dit.py +3 -1
  115. diffusers/pipelines/flux/__init__.py +4 -0
  116. diffusers/pipelines/flux/pipeline_flux.py +34 -26
  117. diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
  118. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
  119. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
  120. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
  121. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
  122. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
  123. diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
  124. diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
  125. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
  126. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  127. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  128. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  129. diffusers/pipelines/flux/pipeline_output.py +6 -4
  130. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
  131. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
  132. diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
  133. diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
  134. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
  135. diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
  136. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  137. diffusers/pipelines/pipeline_loading_utils.py +24 -2
  138. diffusers/pipelines/pipeline_utils.py +22 -15
  139. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
  140. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
  141. diffusers/pipelines/qwenimage/__init__.py +55 -0
  142. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  143. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  144. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +849 -0
  145. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  146. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  147. diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
  148. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  149. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  150. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  151. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  152. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  153. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  154. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  155. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
  156. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  157. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  158. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
  160. diffusers/pipelines/wan/pipeline_wan.py +78 -20
  161. diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
  162. diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
  163. diffusers/quantizers/__init__.py +1 -177
  164. diffusers/quantizers/base.py +11 -0
  165. diffusers/quantizers/gguf/utils.py +92 -3
  166. diffusers/quantizers/pipe_quant_config.py +202 -0
  167. diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
  168. diffusers/schedulers/scheduling_deis_multistep.py +8 -1
  169. diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
  170. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
  171. diffusers/schedulers/scheduling_scm.py +0 -1
  172. diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
  173. diffusers/schedulers/scheduling_utils.py +2 -2
  174. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  175. diffusers/training_utils.py +78 -0
  176. diffusers/utils/__init__.py +10 -0
  177. diffusers/utils/constants.py +4 -0
  178. diffusers/utils/dummy_pt_objects.py +312 -0
  179. diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
  180. diffusers/utils/dynamic_modules_utils.py +84 -25
  181. diffusers/utils/hub_utils.py +33 -17
  182. diffusers/utils/import_utils.py +70 -0
  183. diffusers/utils/peft_utils.py +11 -8
  184. diffusers/utils/testing_utils.py +136 -10
  185. diffusers/utils/torch_utils.py +18 -0
  186. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/METADATA +6 -6
  187. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/RECORD +191 -127
  188. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/LICENSE +0 -0
  189. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/WHEEL +0 -0
  190. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/entry_points.txt +0 -0
  191. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/top_level.txt +0 -0
diffusers/models/transformers/transformer_wan.py
@@ -21,9 +21,10 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ..attention import FeedForward
-from ..attention_processor import Attention
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
@@ -34,18 +35,51 @@ from ..normalization import FP32LayerNorm
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class WanAttnProcessor2_0:
+def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
+    # encoder_hidden_states is only passed for cross-attention
+    if encoder_hidden_states is None:
+        encoder_hidden_states = hidden_states
+
+    if attn.fused_projections:
+        if attn.cross_attention_dim_head is None:
+            # In self-attention layers, we can fuse the entire QKV projection into a single linear
+            query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+        else:
+            # In cross-attention layers, we can only fuse the KV projections into a single linear
+            query = attn.to_q(hidden_states)
+            key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+    else:
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+    return query, key, value
+
+
+def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
+    if attn.fused_projections:
+        key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
+    else:
+        key_img = attn.add_k_proj(encoder_hidden_states_img)
+        value_img = attn.add_v_proj(encoder_hidden_states_img)
+    return key_img, value_img
+
+
+class WanAttnProcessor:
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+            raise ImportError(
+                "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
+            )
 
     def __call__(
         self,
-        attn: Attention,
+        attn: "WanAttention",
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        rotary_emb: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         encoder_hidden_states_img = None
         if attn.add_k_proj is not None:
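A note on the fast path introduced above: when fused_projections is set, a single to_qkv (or to_kv) linear stands in for the separate projections and its output is simply chunked apart. A minimal standalone sketch in plain PyTorch (illustration only, not code from the diffusers source) of why concatenating the Q/K/V weights and biases along the output dimension reproduces the three separate projections:

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    dim, inner_dim = 16, 32
    to_q = nn.Linear(dim, inner_dim, bias=True)
    to_k = nn.Linear(dim, inner_dim, bias=True)
    to_v = nn.Linear(dim, inner_dim, bias=True)

    # Build a fused projection the way fuse_projections does: stack weights and biases
    # along the output dimension, so the fused output chunks back into q, k, v.
    to_qkv = nn.Linear(dim, 3 * inner_dim, bias=True)
    with torch.no_grad():
        to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight]))
        to_qkv.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias]))

    x = torch.randn(2, 4, dim)
    query, key, value = to_qkv(x).chunk(3, dim=-1)     # fused path
    assert torch.allclose(query, to_q(x), atol=1e-6)   # matches the unfused projections
    assert torch.allclose(key, to_k(x), atol=1e-6)
    assert torch.allclose(value, to_v(x), atol=1e-6)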
@@ -53,53 +87,65 @@ class WanAttnProcessor2_0:
             image_context_length = encoder_hidden_states.shape[1] - 512
             encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
             encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
 
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
+        query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
 
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)
 
-        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
 
         if rotary_emb is not None:
 
-            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
-                dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
-                x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
-                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
-                return x_out.type_as(hidden_states)
-
-            query = apply_rotary_emb(query, rotary_emb)
-            key = apply_rotary_emb(key, rotary_emb)
+            def apply_rotary_emb(
+                hidden_states: torch.Tensor,
+                freqs_cos: torch.Tensor,
+                freqs_sin: torch.Tensor,
+            ):
+                x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+                cos = freqs_cos[..., 0::2]
+                sin = freqs_sin[..., 1::2]
+                out = torch.empty_like(hidden_states)
+                out[..., 0::2] = x1 * cos - x2 * sin
+                out[..., 1::2] = x1 * sin + x2 * cos
+                return out.type_as(hidden_states)
+
+            query = apply_rotary_emb(query, *rotary_emb)
+            key = apply_rotary_emb(key, *rotary_emb)
 
         # I2V task
         hidden_states_img = None
         if encoder_hidden_states_img is not None:
-            key_img = attn.add_k_proj(encoder_hidden_states_img)
+            key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
             key_img = attn.norm_added_k(key_img)
-            value_img = attn.add_v_proj(encoder_hidden_states_img)
-
-            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
 
-            hidden_states_img = F.scaled_dot_product_attention(
-                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+            key_img = key_img.unflatten(2, (attn.heads, -1))
+            value_img = value_img.unflatten(2, (attn.heads, -1))
+
+            hidden_states_img = dispatch_attention_fn(
+                query,
+                key_img,
+                value_img,
+                attn_mask=None,
+                dropout_p=0.0,
+                is_causal=False,
+                backend=self._attention_backend,
             )
-            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
+            hidden_states_img = hidden_states_img.flatten(2, 3)
            hidden_states_img = hidden_states_img.type_as(query)
 
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
        )
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.type_as(query)
 
         if hidden_states_img is not None:
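The rewritten apply_rotary_emb performs the same rotation as the old complex-number version, only on interleaved real channel pairs: (x1 + i*x2) * (cos t + i*sin t) expands to (x1*cos t - x2*sin t, x1*sin t + x2*cos t). A small self-contained check in plain PyTorch with simplified shapes (illustration only, not code from the diffusers source):

    import torch

    torch.manual_seed(0)
    batch, seq, heads, head_dim = 1, 3, 2, 8
    x = torch.randn(batch, seq, heads, head_dim)

    # One angle per channel pair; broadcast over batch/seq/heads to keep the sketch small.
    angles = torch.randn(head_dim // 2)
    freqs_cos = torch.repeat_interleave(angles.cos(), 2)   # layout produced by repeat_interleave_real=True
    freqs_sin = torch.repeat_interleave(angles.sin(), 2)

    # New real-valued path (as in the updated apply_rotary_emb).
    x1, x2 = x.unflatten(-1, (-1, 2)).unbind(-1)
    cos, sin = freqs_cos[0::2], freqs_sin[1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos

    # Old complex-valued path.
    freqs = torch.polar(torch.ones_like(angles), angles)   # cos(t) + i*sin(t)
    x_complex = torch.view_as_complex(x.double().unflatten(3, (-1, 2)))
    out_ref = torch.view_as_real(x_complex * freqs).flatten(3, 4).to(x.dtype)

    assert torch.allclose(out, out_ref, atol=1e-5)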
@@ -110,6 +156,122 @@ class WanAttnProcessor2_0:
         return hidden_states
 
 
+class WanAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = (
+            "The WanAttnProcessor2_0 class is deprecated and will be removed in a future version. "
+            "Please use WanAttnProcessor instead. "
+        )
+        deprecate("WanAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+        return WanAttnProcessor(*args, **kwargs)
+
+
+class WanAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = WanAttnProcessor
+    _available_processors = [WanAttnProcessor]
+
+    def __init__(
+        self,
+        dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        eps: float = 1e-5,
+        dropout: float = 0.0,
+        added_kv_proj_dim: Optional[int] = None,
+        cross_attention_dim_head: Optional[int] = None,
+        processor=None,
+        is_cross_attention=None,
+    ):
+        super().__init__()
+
+        self.inner_dim = dim_head * heads
+        self.heads = heads
+        self.added_kv_proj_dim = added_kv_proj_dim
+        self.cross_attention_dim_head = cross_attention_dim_head
+        self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+
+        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+        self.to_out = torch.nn.ModuleList(
+            [
+                torch.nn.Linear(self.inner_dim, dim, bias=True),
+                torch.nn.Dropout(dropout),
+            ]
+        )
+        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+        self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+
+        self.add_k_proj = self.add_v_proj = None
+        if added_kv_proj_dim is not None:
+            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+            self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
+
+        self.is_cross_attention = cross_attention_dim_head is not None
+
+        self.set_processor(processor)
+
+    def fuse_projections(self):
+        if getattr(self, "fused_projections", False):
+            return
+
+        if self.cross_attention_dim_head is None:
+            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+            concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_qkv = nn.Linear(in_features, out_features, bias=True)
+            self.to_qkv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+        else:
+            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+            concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_kv = nn.Linear(in_features, out_features, bias=True)
+            self.to_kv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+
+        if self.added_kv_proj_dim is not None:
+            concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
+            concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
+            self.to_added_kv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+
+        self.fused_projections = True
+
+    @torch.no_grad()
+    def unfuse_projections(self):
+        if not getattr(self, "fused_projections", False):
+            return
+
+        if hasattr(self, "to_qkv"):
+            delattr(self, "to_qkv")
+        if hasattr(self, "to_kv"):
+            delattr(self, "to_kv")
+        if hasattr(self, "to_added_kv"):
+            delattr(self, "to_added_kv")
+
+        self.fused_projections = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
+
+
 class WanImageEmbedding(torch.nn.Module):
     def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
         super().__init__()
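fuse_projections above builds each fused linear on the meta device, so no storage is allocated for throwaway placeholder weights, and then materializes it with load_state_dict(..., assign=True), which adopts the concatenated tensors instead of copying them into freshly allocated parameters. A minimal standalone sketch of that pattern (illustration only, not code from the diffusers source; assign=True assumes PyTorch >= 2.1):

    import torch
    import torch.nn as nn

    to_k = nn.Linear(8, 16, bias=True)
    to_v = nn.Linear(8, 16, bias=True)

    weight = torch.cat([to_k.weight.data, to_v.weight.data])   # (32, 8)
    bias = torch.cat([to_k.bias.data, to_v.bias.data])         # (32,)

    with torch.device("meta"):
        to_kv = nn.Linear(8, 32, bias=True)                    # placeholder parameters, no real storage

    to_kv.load_state_dict({"weight": weight, "bias": bias}, strict=True, assign=True)

    assert to_kv.weight.device.type != "meta"                  # parameters were materialized in place
    assert to_kv.weight.shape == (32, 8)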
@@ -161,8 +323,11 @@ class WanTimeTextImageEmbedding(nn.Module):
         timestep: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         encoder_hidden_states_image: Optional[torch.Tensor] = None,
+        timestep_seq_len: Optional[int] = None,
     ):
         timestep = self.timesteps_proj(timestep)
+        if timestep_seq_len is not None:
+            timestep = timestep.unflatten(0, (-1, timestep_seq_len))
 
         time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
         if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
@@ -179,7 +344,11 @@
 
 class WanRotaryPosEmbed(nn.Module):
     def __init__(
-        self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
+        self,
+        attention_head_dim: int,
+        patch_size: Tuple[int, int, int],
+        max_seq_len: int,
+        theta: float = 10000.0,
     ):
         super().__init__()
 
@@ -189,38 +358,55 @@
 
         h_dim = w_dim = 2 * (attention_head_dim // 6)
         t_dim = attention_head_dim - h_dim - w_dim
-
-        freqs = []
         freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+
+        freqs_cos = []
+        freqs_sin = []
+
         for dim in [t_dim, h_dim, w_dim]:
-            freq = get_1d_rotary_pos_embed(
-                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=freqs_dtype
+            freq_cos, freq_sin = get_1d_rotary_pos_embed(
+                dim,
+                max_seq_len,
+                theta,
+                use_real=True,
+                repeat_interleave_real=True,
+                freqs_dtype=freqs_dtype,
             )
-            freqs.append(freq)
-        self.freqs = torch.cat(freqs, dim=1)
+            freqs_cos.append(freq_cos)
+            freqs_sin.append(freq_sin)
+
+        self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+        self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = hidden_states.shape
         p_t, p_h, p_w = self.patch_size
         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
 
-        freqs = self.freqs.to(hidden_states.device)
-        freqs = freqs.split_with_sizes(
-            [
-                self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
-                self.attention_head_dim // 6,
-                self.attention_head_dim // 6,
-            ],
-            dim=1,
-        )
+        split_sizes = [
+            self.attention_head_dim - 2 * (self.attention_head_dim // 3),
+            self.attention_head_dim // 3,
+            self.attention_head_dim // 3,
+        ]
+
+        freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+        freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
 
-        freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
-        freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
-        freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
-        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
-        return freqs
+        freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
 
+        freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
 
+        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+
+        return freqs_cos, freqs_sin
+
+
+@maybe_allow_in_graph
 class WanTransformerBlock(nn.Module):
     def __init__(
         self,
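The new split sizes are written against the full head dimension because the registered buffers now hold interleaved cos/sin values rather than half-width complex frequencies. A quick arithmetic check for the Wan default head size (illustration only; the concrete numbers assume attention_head_dim = 128):

    attention_head_dim = 128                                    # Wan default head size

    # Per-axis widths computed in __init__ (time, height, width).
    h_dim = w_dim = 2 * (attention_head_dim // 6)               # 42
    t_dim = attention_head_dim - h_dim - w_dim                  # 44

    old_split = [attention_head_dim // 2 - 2 * (attention_head_dim // 6),
                 attention_head_dim // 6,
                 attention_head_dim // 6]                        # complex entries: [22, 21, 21]
    new_split = [attention_head_dim - 2 * (attention_head_dim // 3),
                 attention_head_dim // 3,
                 attention_head_dim // 3]                        # cos/sin entries: [44, 42, 42]

    assert new_split == [t_dim, h_dim, w_dim]
    assert [2 * s for s in old_split] == new_split               # each complex entry became a cos/sin pair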
@@ -236,33 +422,24 @@ class WanTransformerBlock(nn.Module):
 
         # 1. Self-attention
         self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
-        self.attn1 = Attention(
-            query_dim=dim,
+        self.attn1 = WanAttention(
+            dim=dim,
             heads=num_heads,
-            kv_heads=num_heads,
             dim_head=dim // num_heads,
-            qk_norm=qk_norm,
             eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
-            processor=WanAttnProcessor2_0(),
+            cross_attention_dim_head=None,
+            processor=WanAttnProcessor(),
         )
 
         # 2. Cross-attention
-        self.attn2 = Attention(
-            query_dim=dim,
+        self.attn2 = WanAttention(
+            dim=dim,
             heads=num_heads,
-            kv_heads=num_heads,
             dim_head=dim // num_heads,
-            qk_norm=qk_norm,
             eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
             added_kv_proj_dim=added_kv_proj_dim,
-            added_proj_bias=True,
-            processor=WanAttnProcessor2_0(),
+            cross_attention_dim_head=dim // num_heads,
+            processor=WanAttnProcessor(),
         )
         self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
 
@@ -279,18 +456,32 @@
         temb: torch.Tensor,
         rotary_emb: torch.Tensor,
     ) -> torch.Tensor:
-        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
-            self.scale_shift_table + temb.float()
-        ).chunk(6, dim=1)
+        if temb.ndim == 4:
+            # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+                self.scale_shift_table.unsqueeze(0) + temb.float()
+            ).chunk(6, dim=2)
+            # batch_size, seq_len, 1, inner_dim
+            shift_msa = shift_msa.squeeze(2)
+            scale_msa = scale_msa.squeeze(2)
+            gate_msa = gate_msa.squeeze(2)
+            c_shift_msa = c_shift_msa.squeeze(2)
+            c_scale_msa = c_scale_msa.squeeze(2)
+            c_gate_msa = c_gate_msa.squeeze(2)
+        else:
+            # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+                self.scale_shift_table + temb.float()
+            ).chunk(6, dim=1)
 
         # 1. Self-attention
         norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
-        attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+        attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
         hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
 
         # 2. Cross-attention
         norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
-        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
         hidden_states = hidden_states + attn_output
 
         # 3. Feed-forward
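With per-token conditioning, temb reaches the block as (batch_size, seq_len, 6, inner_dim); the block's scale_shift_table, a (1, 6, inner_dim) parameter as implied by the broadcast above, is unsqueezed to (1, 1, 6, inner_dim), and chunking along dim=2 followed by squeeze(2) yields per-token shift/scale/gate tensors. A shape-only sketch in plain PyTorch (illustration only, not code from the diffusers source):

    import torch

    batch, seq_len, inner_dim = 2, 6, 8
    scale_shift_table = torch.randn(1, 6, inner_dim)            # stand-in for the block parameter
    temb = torch.randn(batch, seq_len, 6, inner_dim)            # per-token modulation (wan2.2 ti2v)

    chunks = (scale_shift_table.unsqueeze(0) + temb.float()).chunk(6, dim=2)
    shift_msa = chunks[0].squeeze(2)
    assert shift_msa.shape == (batch, seq_len, inner_dim)       # else branch yields (batch, inner_dim) instead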
@@ -303,7 +494,9 @@
         return hidden_states
 
 
-class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+class WanTransformer3DModel(
+    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
     r"""
     A Transformer model for video-like data used in the Wan model.
 
@@ -345,6 +538,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
     _no_split_modules = ["WanTransformerBlock"]
     _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
     _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+    _repeated_blocks = ["WanTransformerBlock"]
 
     @register_to_config
     def __init__(
@@ -438,10 +632,22 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
         hidden_states = self.patch_embedding(hidden_states)
         hidden_states = hidden_states.flatten(2).transpose(1, 2)
 
+        # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
+        if timestep.ndim == 2:
+            ts_seq_len = timestep.shape[1]
+            timestep = timestep.flatten()  # batch_size * seq_len
+        else:
+            ts_seq_len = None
+
         temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
-            timestep, encoder_hidden_states, encoder_hidden_states_image
+            timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
         )
-        timestep_proj = timestep_proj.unflatten(1, (6, -1))
+        if ts_seq_len is not None:
+            # batch_size, seq_len, 6, inner_dim
+            timestep_proj = timestep_proj.unflatten(2, (6, -1))
+        else:
+            # batch_size, 6, inner_dim
+            timestep_proj = timestep_proj.unflatten(1, (6, -1))
 
         if encoder_hidden_states_image is not None:
             encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
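In the Wan 2.2 ti2v path a timestep is supplied per token, so the model flattens the (batch_size, seq_len) timestep map before the sinusoidal projection and the condition embedder unflattens the result right after it. A small shape sketch with stand-in tensors (illustration only, not code from the diffusers source):

    import torch

    batch_size, seq_len, proj_dim = 2, 6, 16
    timestep = torch.randint(0, 1000, (batch_size, seq_len))    # one timestep per token

    flat = timestep.flatten()                                    # (batch_size * seq_len,)
    proj = torch.randn(flat.shape[0], proj_dim)                  # stand-in for the timesteps_proj output
    per_token = proj.unflatten(0, (-1, seq_len))                 # (batch_size, seq_len, proj_dim)
    assert per_token.shape == (batch_size, seq_len, proj_dim)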
@@ -457,7 +663,14 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
                 hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
 
         # 5. Output norm, projection & unpatchify
-        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+        if temb.ndim == 3:
+            # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
+            shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
+            shift = shift.squeeze(2)
+            scale = scale.squeeze(2)
+        else:
+            # batch_size, inner_dim
+            shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
 
         # Move the shift and scale tensors to the same device as hidden_states.
         # When using multi-GPU inference via accelerate these will be on the
diffusers/models/transformers/transformer_wan_vace.py
@@ -22,12 +22,17 @@ from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import FeedForward
-from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import FP32LayerNorm
-from .transformer_wan import WanAttnProcessor2_0, WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock
+from .transformer_wan import (
+    WanAttention,
+    WanAttnProcessor,
+    WanRotaryPosEmbed,
+    WanTimeTextImageEmbedding,
+    WanTransformerBlock,
+)
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -55,33 +60,22 @@ class WanVACETransformerBlock(nn.Module):
 
         # 2. Self-attention
         self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
-        self.attn1 = Attention(
-            query_dim=dim,
+        self.attn1 = WanAttention(
+            dim=dim,
             heads=num_heads,
-            kv_heads=num_heads,
             dim_head=dim // num_heads,
-            qk_norm=qk_norm,
             eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
-            processor=WanAttnProcessor2_0(),
+            processor=WanAttnProcessor(),
         )
 
         # 3. Cross-attention
-        self.attn2 = Attention(
-            query_dim=dim,
+        self.attn2 = WanAttention(
+            dim=dim,
             heads=num_heads,
-            kv_heads=num_heads,
             dim_head=dim // num_heads,
-            qk_norm=qk_norm,
             eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
             added_kv_proj_dim=added_kv_proj_dim,
-            added_proj_bias=True,
-            processor=WanAttnProcessor2_0(),
+            processor=WanAttnProcessor(),
         )
         self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
 
@@ -116,12 +110,12 @@ class WanVACETransformerBlock(nn.Module):
         norm_hidden_states = (self.norm1(control_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(
             control_hidden_states
         )
-        attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+        attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
         control_hidden_states = (control_hidden_states.float() + attn_output * gate_msa).type_as(control_hidden_states)
 
         # 2. Cross-attention
         norm_hidden_states = self.norm2(control_hidden_states.float()).type_as(control_hidden_states)
-        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
         control_hidden_states = control_hidden_states + attn_output
 
         # 3. Feed-forward
diffusers/models/unets/unet_2d_condition.py
@@ -165,8 +165,9 @@ class UNet2DConditionModel(
     """
 
     _supports_gradient_checkpointing = True
-    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
+    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"]
     _skip_layerwise_casting_patterns = ["norm"]
+    _repeated_blocks = ["BasicTransformerBlock"]
 
     @register_to_config
     def __init__(