diffusers-0.31.0-py3-none-any.whl → diffusers-0.32.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (214)
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +74 -28
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/METADATA +69 -69
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
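Several of the entries above are entirely new modules (new video models and pipelines, plus GGUF and torchao quantization support). As a quick orientation, here is a minimal smoke test (not part of the diff; the top-level export names are assumptions inferred from the new module paths above and should be verified against the installed 0.32.0 wheel):

import importlib.metadata

# Assumed public re-exports of the new 0.32.0 modules listed above; adjust if
# any name differs in the installed package.
from diffusers import (
    AutoencoderKLHunyuanVideo,       # models/autoencoders/autoencoder_kl_hunyuan_video.py
    GGUFQuantizationConfig,          # quantizers/gguf/
    HunyuanVideoPipeline,            # pipelines/hunyuan_video/
    HunyuanVideoTransformer3DModel,  # models/transformers/transformer_hunyuan_video.py
    LTXPipeline,                     # pipelines/ltx/
    MochiPipeline,                   # pipelines/mochi/
    SanaPipeline,                    # pipelines/sana/
    TorchAoConfig,                   # quantizers/torchao/
)

print("diffusers", importlib.metadata.version("diffusers"), "- new entry points import OK")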
diffusers/models/transformers/transformer_hunyuan_video.py (new file)
@@ -0,0 +1,789 @@
+# Copyright 2024 The Hunyuan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from diffusers.loaders import FromOriginalModelMixin
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import FeedForward
+from ..attention_processor import Attention, AttentionProcessor
+from ..embeddings import (
+    CombinedTimestepGuidanceTextProjEmbeddings,
+    CombinedTimestepTextProjEmbeddings,
+    get_1d_rotary_pos_embed,
+)
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class HunyuanVideoAttnProcessor2_0:
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "HunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if attn.add_q_proj is None and encoder_hidden_states is not None:
+            hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+        # 1. QKV projections
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+        # 2. QK normalization
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # 3. Rotational positional embeddings applied to latent stream
+        if image_rotary_emb is not None:
+            from ..embeddings import apply_rotary_emb
+
+            if attn.add_q_proj is None and encoder_hidden_states is not None:
+                query = torch.cat(
+                    [
+                        apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
+                        query[:, :, -encoder_hidden_states.shape[1] :],
+                    ],
+                    dim=2,
+                )
+                key = torch.cat(
+                    [
+                        apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
+                        key[:, :, -encoder_hidden_states.shape[1] :],
+                    ],
+                    dim=2,
+                )
+            else:
+                query = apply_rotary_emb(query, image_rotary_emb)
+                key = apply_rotary_emb(key, image_rotary_emb)
+
+        # 4. Encoder condition QKV projection and normalization
+        if attn.add_q_proj is not None and encoder_hidden_states is not None:
+            encoder_query = attn.add_q_proj(encoder_hidden_states)
+            encoder_key = attn.add_k_proj(encoder_hidden_states)
+            encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+            encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+            encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_query = attn.norm_added_q(encoder_query)
+            if attn.norm_added_k is not None:
+                encoder_key = attn.norm_added_k(encoder_key)
+
+            query = torch.cat([query, encoder_query], dim=2)
+            key = torch.cat([key, encoder_key], dim=2)
+            value = torch.cat([value, encoder_value], dim=2)
+
+        # 5. Attention
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # 6. Output projection
+        if encoder_hidden_states is not None:
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, : -encoder_hidden_states.shape[1]],
+                hidden_states[:, -encoder_hidden_states.shape[1] :],
+            )
+
+            if getattr(attn, "to_out", None) is not None:
+                hidden_states = attn.to_out[0](hidden_states)
+                hidden_states = attn.to_out[1](hidden_states)
+
+            if getattr(attn, "to_add_out", None) is not None:
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        return hidden_states, encoder_hidden_states
+
+
+class HunyuanVideoPatchEmbed(nn.Module):
+    def __init__(
+        self,
+        patch_size: Union[int, Tuple[int, int, int]] = 16,
+        in_chans: int = 3,
+        embed_dim: int = 768,
+    ) -> None:
+        super().__init__()
+
+        patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size
+        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.proj(hidden_states)
+        hidden_states = hidden_states.flatten(2).transpose(1, 2)  # BCFHW -> BNC
+        return hidden_states
+
+
+class HunyuanVideoAdaNorm(nn.Module):
+    def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
+        super().__init__()
+
+        out_features = out_features or 2 * in_features
+        self.linear = nn.Linear(in_features, out_features)
+        self.nonlinearity = nn.SiLU()
+
+    def forward(
+        self, temb: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        temb = self.linear(self.nonlinearity(temb))
+        gate_msa, gate_mlp = temb.chunk(2, dim=1)
+        gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
+        return gate_msa, gate_mlp
+
+
+class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        mlp_width_ratio: str = 4.0,
+        mlp_drop_rate: float = 0.0,
+        attention_bias: bool = True,
+    ) -> None:
+        super().__init__()
+
+        hidden_size = num_attention_heads * attention_head_dim
+
+        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
+        self.attn = Attention(
+            query_dim=hidden_size,
+            cross_attention_dim=None,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            bias=attention_bias,
+        )
+
+        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
+        self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
+
+        self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        norm_hidden_states = self.norm1(hidden_states)
+
+        attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=None,
+            attention_mask=attention_mask,
+        )
+
+        gate_msa, gate_mlp = self.norm_out(temb)
+        hidden_states = hidden_states + attn_output * gate_msa
+
+        ff_output = self.ff(self.norm2(hidden_states))
+        hidden_states = hidden_states + ff_output * gate_mlp
+
+        return hidden_states
+
+
+class HunyuanVideoIndividualTokenRefiner(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        num_layers: int,
+        mlp_width_ratio: float = 4.0,
+        mlp_drop_rate: float = 0.0,
+        attention_bias: bool = True,
+    ) -> None:
+        super().__init__()
+
+        self.refiner_blocks = nn.ModuleList(
+            [
+                HunyuanVideoIndividualTokenRefinerBlock(
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    mlp_width_ratio=mlp_width_ratio,
+                    mlp_drop_rate=mlp_drop_rate,
+                    attention_bias=attention_bias,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> None:
+        self_attn_mask = None
+        if attention_mask is not None:
+            batch_size = attention_mask.shape[0]
+            seq_len = attention_mask.shape[1]
+            attention_mask = attention_mask.to(hidden_states.device).bool()
+            self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
+            self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
+            self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
+            self_attn_mask[:, :, :, 0] = True
+
+        for block in self.refiner_blocks:
+            hidden_states = block(hidden_states, temb, self_attn_mask)
+
+        return hidden_states
+
+
+class HunyuanVideoTokenRefiner(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        num_layers: int,
+        mlp_ratio: float = 4.0,
+        mlp_drop_rate: float = 0.0,
+        attention_bias: bool = True,
+    ) -> None:
+        super().__init__()
+
+        hidden_size = num_attention_heads * attention_head_dim
+
+        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
+            embedding_dim=hidden_size, pooled_projection_dim=in_channels
+        )
+        self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
+        self.token_refiner = HunyuanVideoIndividualTokenRefiner(
+            num_attention_heads=num_attention_heads,
+            attention_head_dim=attention_head_dim,
+            num_layers=num_layers,
+            mlp_width_ratio=mlp_ratio,
+            mlp_drop_rate=mlp_drop_rate,
+            attention_bias=attention_bias,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ) -> torch.Tensor:
+        if attention_mask is None:
+            pooled_projections = hidden_states.mean(dim=1)
+        else:
+            original_dtype = hidden_states.dtype
+            mask_float = attention_mask.float().unsqueeze(-1)
+            pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
+            pooled_projections = pooled_projections.to(original_dtype)
+
+        temb = self.time_text_embed(timestep, pooled_projections)
+        hidden_states = self.proj_in(hidden_states)
+        hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
+
+        return hidden_states
+
+
+class HunyuanVideoRotaryPosEmbed(nn.Module):
+    def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
+        super().__init__()
+
+        self.patch_size = patch_size
+        self.patch_size_t = patch_size_t
+        self.rope_dim = rope_dim
+        self.theta = theta
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+        rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size]
+
+        axes_grids = []
+        for i in range(3):
+            # Note: The following line diverges from original behaviour. We create the grid on the device, whereas
+            # original implementation creates it on CPU and then moves it to device. This results in numerical
+            # differences in layerwise debugging outputs, but visually it is the same.
+            grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
+            axes_grids.append(grid)
+        grid = torch.meshgrid(*axes_grids, indexing="ij")  # [W, H, T]
+        grid = torch.stack(grid, dim=0)  # [3, W, H, T]
+
+        freqs = []
+        for i in range(3):
+            freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
+            freqs.append(freq)
+
+        freqs_cos = torch.cat([f[0] for f in freqs], dim=1)  # (W * H * T, D / 2)
+        freqs_sin = torch.cat([f[1] for f in freqs], dim=1)  # (W * H * T, D / 2)
+        return freqs_cos, freqs_sin
+
+
+class HunyuanVideoSingleTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        mlp_ratio: float = 4.0,
+        qk_norm: str = "rms_norm",
+    ) -> None:
+        super().__init__()
+
+        hidden_size = num_attention_heads * attention_head_dim
+        mlp_dim = int(hidden_size * mlp_ratio)
+
+        self.attn = Attention(
+            query_dim=hidden_size,
+            cross_attention_dim=None,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=hidden_size,
+            bias=True,
+            processor=HunyuanVideoAttnProcessor2_0(),
+            qk_norm=qk_norm,
+            eps=1e-6,
+            pre_only=True,
+        )
+
+        self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
+        self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
+        self.act_mlp = nn.GELU(approximate="tanh")
+        self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        text_seq_length = encoder_hidden_states.shape[1]
+        hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+        residual = hidden_states
+
+        # 1. Input normalization
+        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+
+        norm_hidden_states, norm_encoder_hidden_states = (
+            norm_hidden_states[:, :-text_seq_length, :],
+            norm_hidden_states[:, -text_seq_length:, :],
+        )
+
+        # 2. Attention
+        attn_output, context_attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            attention_mask=attention_mask,
+            image_rotary_emb=image_rotary_emb,
+        )
+        attn_output = torch.cat([attn_output, context_attn_output], dim=1)

+        # 3. Modulation and residual connection
+        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+        hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
+        hidden_states = hidden_states + residual
+
+        hidden_states, encoder_hidden_states = (
+            hidden_states[:, :-text_seq_length, :],
+            hidden_states[:, -text_seq_length:, :],
+        )
+        return hidden_states, encoder_hidden_states
+
+
+class HunyuanVideoTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        mlp_ratio: float,
+        qk_norm: str = "rms_norm",
+    ) -> None:
+        super().__init__()
+
+        hidden_size = num_attention_heads * attention_head_dim
+
+        self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
+        self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
+
+        self.attn = Attention(
+            query_dim=hidden_size,
+            cross_attention_dim=None,
+            added_kv_proj_dim=hidden_size,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=hidden_size,
+            context_pre_only=False,
+            bias=True,
+            processor=HunyuanVideoAttnProcessor2_0(),
+            qk_norm=qk_norm,
+            eps=1e-6,
+        )
+
+        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+        self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # 1. Input normalization
+        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+            encoder_hidden_states, emb=temb
+        )
+
+        # 2. Joint attention
+        attn_output, context_attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            attention_mask=attention_mask,
+            image_rotary_emb=freqs_cis,
+        )
+
+        # 3. Modulation and residual connection
+        hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
+        encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
+
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+        # 4. Feed-forward
+        ff_output = self.ff(norm_hidden_states)
+        context_ff_output = self.ff_context(norm_encoder_hidden_states)
+
+        hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
+        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+        return hidden_states, encoder_hidden_states
+
+
+class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+    r"""
+    A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
+
+    Args:
+        in_channels (`int`, defaults to `16`):
+            The number of channels in the input.
+        out_channels (`int`, defaults to `16`):
+            The number of channels in the output.
+        num_attention_heads (`int`, defaults to `24`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of channels in each head.
+        num_layers (`int`, defaults to `20`):
+            The number of layers of dual-stream blocks to use.
+        num_single_layers (`int`, defaults to `40`):
+            The number of layers of single-stream blocks to use.
+        num_refiner_layers (`int`, defaults to `2`):
+            The number of layers of refiner blocks to use.
+        mlp_ratio (`float`, defaults to `4.0`):
+            The ratio of the hidden layer size to the input size in the feedforward network.
+        patch_size (`int`, defaults to `2`):
+            The size of the spatial patches to use in the patch embedding layer.
+        patch_size_t (`int`, defaults to `1`):
+            The size of the tmeporal patches to use in the patch embedding layer.
+        qk_norm (`str`, defaults to `rms_norm`):
+            The normalization to use for the query and key projections in the attention layers.
+        guidance_embeds (`bool`, defaults to `True`):
+            Whether to use guidance embeddings in the model.
+        text_embed_dim (`int`, defaults to `4096`):
+            Input dimension of text embeddings from the text encoder.
+        pooled_projection_dim (`int`, defaults to `768`):
+            The dimension of the pooled projection of the text embeddings.
+        rope_theta (`float`, defaults to `256.0`):
+            The value of theta to use in the RoPE layer.
+        rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
+            The dimensions of the axes to use in the RoPE layer.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 16,
+        out_channels: int = 16,
+        num_attention_heads: int = 24,
+        attention_head_dim: int = 128,
+        num_layers: int = 20,
+        num_single_layers: int = 40,
+        num_refiner_layers: int = 2,
+        mlp_ratio: float = 4.0,
+        patch_size: int = 2,
+        patch_size_t: int = 1,
+        qk_norm: str = "rms_norm",
+        guidance_embeds: bool = True,
+        text_embed_dim: int = 4096,
+        pooled_projection_dim: int = 768,
+        rope_theta: float = 256.0,
+        rope_axes_dim: Tuple[int] = (16, 56, 56),
+    ) -> None:
+        super().__init__()
+
+        inner_dim = num_attention_heads * attention_head_dim
+        out_channels = out_channels or in_channels
+
+        # 1. Latent and condition embedders
+        self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
+        self.context_embedder = HunyuanVideoTokenRefiner(
+            text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
+        )
+        self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
+
+        # 2. RoPE
+        self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
+
+        # 3. Dual stream transformer blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                HunyuanVideoTransformerBlock(
+                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        # 4. Single stream transformer blocks
+        self.single_transformer_blocks = nn.ModuleList(
+            [
+                HunyuanVideoSingleTransformerBlock(
+                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                )
+                for _ in range(num_single_layers)
+            ]
+        )
+
+        # 5. Output projection
+        self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
+        self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
+
+        self.gradient_checkpointing = False
+
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_attention_mask: torch.Tensor,
+        pooled_projections: torch.Tensor,
+        guidance: torch.Tensor = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+        p, p_t = self.config.patch_size, self.config.patch_size_t
+        post_patch_num_frames = num_frames // p_t
+        post_patch_height = height // p
+        post_patch_width = width // p
+
+        # 1. RoPE
+        image_rotary_emb = self.rope(hidden_states)
+
+        # 2. Conditional embeddings
+        temb = self.time_text_embed(timestep, guidance, pooled_projections)
+        hidden_states = self.x_embedder(hidden_states)
+        encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
+
+        # 3. Attention mask preparation
+        latent_sequence_length = hidden_states.shape[1]
+        condition_sequence_length = encoder_hidden_states.shape[1]
+        sequence_length = latent_sequence_length + condition_sequence_length
+        attention_mask = torch.zeros(
+            batch_size, sequence_length, sequence_length, device=hidden_states.device, dtype=torch.bool
+        )  # [B, N, N]
+
+        effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int)  # [B,]
+        effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
+
+        for i in range(batch_size):
+            attention_mask[i, : effective_sequence_length[i], : effective_sequence_length[i]] = True
+
+        # 4. Transformer blocks
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+            def create_custom_forward(module, return_dict=None):
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
+
+                return custom_forward
+
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+
+            for block in self.transformer_blocks:
+                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    attention_mask,
+                    image_rotary_emb,
+                    **ckpt_kwargs,
+                )
+
+            for block in self.single_transformer_blocks:
+                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    attention_mask,
+                    image_rotary_emb,
+                    **ckpt_kwargs,
+                )
+
+        else:
+            for block in self.transformer_blocks:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+                )
+
+            for block in self.single_transformer_blocks:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+                )
+
+        # 5. Output projection
+        hidden_states = self.norm_out(hidden_states, temb)
+        hidden_states = self.proj_out(hidden_states)
+
+        hidden_states = hidden_states.reshape(
+            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
+        )
+        hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
+        hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (hidden_states,)
+
+        return Transformer2DModelOutput(sample=hidden_states)
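
The hunk above is the complete new HunyuanVideoTransformer3DModel module. For orientation only, the following sketch (not part of the diff) instantiates the class with a deliberately tiny, assumed configuration and runs one forward pass on random tensors shaped to match the forward() signature shown above; the released checkpoints use the much larger defaults documented in the class docstring.

import torch
from diffusers import HunyuanVideoTransformer3DModel

# Tiny illustrative configuration (assumed values for a quick smoke test,
# not the released defaults). rope_axes_dim must sum to attention_head_dim.
model = HunyuanVideoTransformer3DModel(
    in_channels=4,
    out_channels=4,
    num_attention_heads=2,
    attention_head_dim=10,
    num_layers=1,
    num_single_layers=1,
    num_refiner_layers=1,
    patch_size=1,
    patch_size_t=1,
    text_embed_dim=16,
    pooled_projection_dim=8,
    rope_axes_dim=(2, 4, 4),
)

batch, frames, height, width, text_len = 1, 2, 8, 8, 6
hidden_states = torch.randn(batch, 4, frames, height, width)            # latent video, [B, C, F, H, W]
timestep = torch.randint(0, 1000, (batch,))                             # diffusion timestep, [B]
encoder_hidden_states = torch.randn(batch, text_len, 16)                # text tokens, [B, L, text_embed_dim]
encoder_attention_mask = torch.ones(batch, text_len, dtype=torch.bool)  # valid-token mask, [B, L]
pooled_projections = torch.randn(batch, 8)                              # pooled text embedding, [B, pooled_projection_dim]
guidance = torch.full((batch,), 6.0) * 1000.0                           # embedded guidance scale (guidance_embeds=True)

with torch.no_grad():
    output = model(
        hidden_states=hidden_states,
        timestep=timestep,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        pooled_projections=pooled_projections,
        guidance=guidance,
    )

print(output.sample.shape)  # [B, out_channels, F, H, W] -> torch.Size([1, 4, 2, 8, 8])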