diffusers 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +66 -5
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +25 -17
- diffusers/loaders/__init__.py +22 -3
- diffusers/loaders/ip_adapter.py +538 -15
- diffusers/loaders/lora_base.py +124 -118
- diffusers/loaders/lora_conversion_utils.py +318 -3
- diffusers/loaders/lora_pipeline.py +1688 -368
- diffusers/loaders/peft.py +379 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +519 -9
- diffusers/loaders/textual_inversion.py +3 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +17 -4
- diffusers/models/__init__.py +47 -14
- diffusers/models/activations.py +22 -9
- diffusers/models/attention.py +13 -4
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2059 -281
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +2 -1
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +29 -495
- diffusers/models/controlnet_sd3.py +25 -379
- diffusers/models/controlnet_sparsectrl.py +46 -718
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +838 -43
- diffusers/models/model_loading_utils.py +88 -6
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +72 -26
- diffusers/models/normalization.py +78 -13
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +1 -1
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +1 -1
- diffusers/models/transformers/transformer_flux.py +30 -9
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +105 -17
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +1 -1
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +5 -5
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +8 -0
- diffusers/pipelines/__init__.py +34 -0
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
- diffusers/pipelines/auto_pipeline.py +53 -6
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
- diffusers/pipelines/flux/__init__.py +13 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +204 -29
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +7 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +25 -4
- diffusers/pipelines/pipeline_utils.py +35 -6
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/auto.py +14 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +280 -2
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
- diffusers/schedulers/scheduling_ddpm.py +2 -6
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
- diffusers/schedulers/scheduling_deis_multistep.py +28 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
- diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
- diffusers/schedulers/scheduling_euler_discrete.py +4 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_heun_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +28 -9
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
- diffusers/training_utils.py +16 -2
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +180 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +31 -39
- diffusers/utils/import_utils.py +67 -0
- diffusers/utils/peft_utils.py +3 -0
- diffusers/utils/testing_utils.py +56 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/METADATA +6 -6
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/RECORD +214 -162
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/WHEEL +1 -1
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/LICENSE +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/top_level.txt +0 -0
diffusers/models/transformers/transformer_hunyuan_video.py (new file)
@@ -0,0 +1,789 @@
+# Copyright 2024 The Hunyuan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from diffusers.loaders import FromOriginalModelMixin
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import FeedForward
+from ..attention_processor import Attention, AttentionProcessor
+from ..embeddings import (
+    CombinedTimestepGuidanceTextProjEmbeddings,
+    CombinedTimestepTextProjEmbeddings,
+    get_1d_rotary_pos_embed,
+)
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class HunyuanVideoAttnProcessor2_0:
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "HunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if attn.add_q_proj is None and encoder_hidden_states is not None:
+            hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+        # 1. QKV projections
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+        # 2. QK normalization
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # 3. Rotational positional embeddings applied to latent stream
+        if image_rotary_emb is not None:
+            from ..embeddings import apply_rotary_emb
+
+            if attn.add_q_proj is None and encoder_hidden_states is not None:
+                query = torch.cat(
+                    [
+                        apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
+                        query[:, :, -encoder_hidden_states.shape[1] :],
+                    ],
+                    dim=2,
+                )
+                key = torch.cat(
+                    [
+                        apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
+                        key[:, :, -encoder_hidden_states.shape[1] :],
+                    ],
+                    dim=2,
+                )
+            else:
+                query = apply_rotary_emb(query, image_rotary_emb)
+                key = apply_rotary_emb(key, image_rotary_emb)
+
+        # 4. Encoder condition QKV projection and normalization
+        if attn.add_q_proj is not None and encoder_hidden_states is not None:
+            encoder_query = attn.add_q_proj(encoder_hidden_states)
+            encoder_key = attn.add_k_proj(encoder_hidden_states)
+            encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+            encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+            encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_query = attn.norm_added_q(encoder_query)
+            if attn.norm_added_k is not None:
+                encoder_key = attn.norm_added_k(encoder_key)
+
+            query = torch.cat([query, encoder_query], dim=2)
+            key = torch.cat([key, encoder_key], dim=2)
+            value = torch.cat([value, encoder_value], dim=2)
+
+        # 5. Attention
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # 6. Output projection
+        if encoder_hidden_states is not None:
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, : -encoder_hidden_states.shape[1]],
+                hidden_states[:, -encoder_hidden_states.shape[1] :],
+            )
+
+            if getattr(attn, "to_out", None) is not None:
+                hidden_states = attn.to_out[0](hidden_states)
+                hidden_states = attn.to_out[1](hidden_states)
+
+            if getattr(attn, "to_add_out", None) is not None:
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        return hidden_states, encoder_hidden_states
+
+
+class HunyuanVideoPatchEmbed(nn.Module):
+    def __init__(
+        self,
+        patch_size: Union[int, Tuple[int, int, int]] = 16,
+        in_chans: int = 3,
+        embed_dim: int = 768,
+    ) -> None:
+        super().__init__()
+
+        patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size
+        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.proj(hidden_states)
+        hidden_states = hidden_states.flatten(2).transpose(1, 2)  # BCFHW -> BNC
+        return hidden_states
+
+
+class HunyuanVideoAdaNorm(nn.Module):
+    def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
+        super().__init__()
+
+        out_features = out_features or 2 * in_features
+        self.linear = nn.Linear(in_features, out_features)
+        self.nonlinearity = nn.SiLU()
+
+    def forward(
+        self, temb: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        temb = self.linear(self.nonlinearity(temb))
+        gate_msa, gate_mlp = temb.chunk(2, dim=1)
+        gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
+        return gate_msa, gate_mlp
+
+
+class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        mlp_width_ratio: str = 4.0,
+        mlp_drop_rate: float = 0.0,
+        attention_bias: bool = True,
+    ) -> None:
+        super().__init__()
+
+        hidden_size = num_attention_heads * attention_head_dim
+
+        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
+        self.attn = Attention(
+            query_dim=hidden_size,
+            cross_attention_dim=None,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            bias=attention_bias,
+        )
+
+        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
+        self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
+
+        self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        norm_hidden_states = self.norm1(hidden_states)
+
+        attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=None,
+            attention_mask=attention_mask,
+        )
+
+        gate_msa, gate_mlp = self.norm_out(temb)
+        hidden_states = hidden_states + attn_output * gate_msa
+
+        ff_output = self.ff(self.norm2(hidden_states))
+        hidden_states = hidden_states + ff_output * gate_mlp
+
+        return hidden_states
+
+
+class HunyuanVideoIndividualTokenRefiner(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        num_layers: int,
+        mlp_width_ratio: float = 4.0,
+        mlp_drop_rate: float = 0.0,
+        attention_bias: bool = True,
+    ) -> None:
+        super().__init__()
+
+        self.refiner_blocks = nn.ModuleList(
+            [
+                HunyuanVideoIndividualTokenRefinerBlock(
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    mlp_width_ratio=mlp_width_ratio,
+                    mlp_drop_rate=mlp_drop_rate,
+                    attention_bias=attention_bias,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> None:
+        self_attn_mask = None
+        if attention_mask is not None:
+            batch_size = attention_mask.shape[0]
+            seq_len = attention_mask.shape[1]
+            attention_mask = attention_mask.to(hidden_states.device).bool()
+            self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
+            self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
+            self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
+            self_attn_mask[:, :, :, 0] = True
+
+        for block in self.refiner_blocks:
+            hidden_states = block(hidden_states, temb, self_attn_mask)
+
+        return hidden_states
+
+
+class HunyuanVideoTokenRefiner(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        num_layers: int,
+        mlp_ratio: float = 4.0,
+        mlp_drop_rate: float = 0.0,
+        attention_bias: bool = True,
+    ) -> None:
+        super().__init__()
+
+        hidden_size = num_attention_heads * attention_head_dim
+
+        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
+            embedding_dim=hidden_size, pooled_projection_dim=in_channels
+        )
+        self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
+        self.token_refiner = HunyuanVideoIndividualTokenRefiner(
+            num_attention_heads=num_attention_heads,
+            attention_head_dim=attention_head_dim,
+            num_layers=num_layers,
+            mlp_width_ratio=mlp_ratio,
+            mlp_drop_rate=mlp_drop_rate,
+            attention_bias=attention_bias,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+    ) -> torch.Tensor:
+        if attention_mask is None:
+            pooled_projections = hidden_states.mean(dim=1)
+        else:
+            original_dtype = hidden_states.dtype
+            mask_float = attention_mask.float().unsqueeze(-1)
+            pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
+            pooled_projections = pooled_projections.to(original_dtype)
+
+        temb = self.time_text_embed(timestep, pooled_projections)
+        hidden_states = self.proj_in(hidden_states)
+        hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
+
+        return hidden_states
+
+
+class HunyuanVideoRotaryPosEmbed(nn.Module):
+    def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
+        super().__init__()
+
+        self.patch_size = patch_size
+        self.patch_size_t = patch_size_t
+        self.rope_dim = rope_dim
+        self.theta = theta
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+        rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size]
+
+        axes_grids = []
+        for i in range(3):
+            # Note: The following line diverges from original behaviour. We create the grid on the device, whereas
+            # original implementation creates it on CPU and then moves it to device. This results in numerical
+            # differences in layerwise debugging outputs, but visually it is the same.
+            grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
+            axes_grids.append(grid)
+        grid = torch.meshgrid(*axes_grids, indexing="ij")  # [W, H, T]
+        grid = torch.stack(grid, dim=0)  # [3, W, H, T]
+
+        freqs = []
+        for i in range(3):
+            freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
+            freqs.append(freq)
+
+        freqs_cos = torch.cat([f[0] for f in freqs], dim=1)  # (W * H * T, D / 2)
+        freqs_sin = torch.cat([f[1] for f in freqs], dim=1)  # (W * H * T, D / 2)
+        return freqs_cos, freqs_sin
+
+
+class HunyuanVideoSingleTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        mlp_ratio: float = 4.0,
+        qk_norm: str = "rms_norm",
+    ) -> None:
+        super().__init__()
+
+        hidden_size = num_attention_heads * attention_head_dim
+        mlp_dim = int(hidden_size * mlp_ratio)
+
+        self.attn = Attention(
+            query_dim=hidden_size,
+            cross_attention_dim=None,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=hidden_size,
+            bias=True,
+            processor=HunyuanVideoAttnProcessor2_0(),
+            qk_norm=qk_norm,
+            eps=1e-6,
+            pre_only=True,
+        )
+
+        self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
+        self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
+        self.act_mlp = nn.GELU(approximate="tanh")
+        self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        text_seq_length = encoder_hidden_states.shape[1]
+        hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+        residual = hidden_states
+
+        # 1. Input normalization
+        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+
+        norm_hidden_states, norm_encoder_hidden_states = (
+            norm_hidden_states[:, :-text_seq_length, :],
+            norm_hidden_states[:, -text_seq_length:, :],
+        )
+
+        # 2. Attention
+        attn_output, context_attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            attention_mask=attention_mask,
+            image_rotary_emb=image_rotary_emb,
+        )
+        attn_output = torch.cat([attn_output, context_attn_output], dim=1)
+
+        # 3. Modulation and residual connection
+        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+        hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
+        hidden_states = hidden_states + residual
+
+        hidden_states, encoder_hidden_states = (
+            hidden_states[:, :-text_seq_length, :],
+            hidden_states[:, -text_seq_length:, :],
+        )
+        return hidden_states, encoder_hidden_states
+
+
+class HunyuanVideoTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        mlp_ratio: float,
+        qk_norm: str = "rms_norm",
+    ) -> None:
+        super().__init__()
+
+        hidden_size = num_attention_heads * attention_head_dim
+
+        self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
+        self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
+
+        self.attn = Attention(
+            query_dim=hidden_size,
+            cross_attention_dim=None,
+            added_kv_proj_dim=hidden_size,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=hidden_size,
+            context_pre_only=False,
+            bias=True,
+            processor=HunyuanVideoAttnProcessor2_0(),
+            qk_norm=qk_norm,
+            eps=1e-6,
+        )
+
+        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+        self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # 1. Input normalization
+        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+            encoder_hidden_states, emb=temb
+        )
+
+        # 2. Joint attention
+        attn_output, context_attn_output = self.attn(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            attention_mask=attention_mask,
+            image_rotary_emb=freqs_cis,
+        )
+
+        # 3. Modulation and residual connection
+        hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
+        encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
+
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+        # 4. Feed-forward
+        ff_output = self.ff(norm_hidden_states)
+        context_ff_output = self.ff_context(norm_encoder_hidden_states)
+
+        hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
+        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+        return hidden_states, encoder_hidden_states
+
+
+class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+    r"""
+    A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
+
+    Args:
+        in_channels (`int`, defaults to `16`):
+            The number of channels in the input.
+        out_channels (`int`, defaults to `16`):
+            The number of channels in the output.
+        num_attention_heads (`int`, defaults to `24`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of channels in each head.
+        num_layers (`int`, defaults to `20`):
+            The number of layers of dual-stream blocks to use.
+        num_single_layers (`int`, defaults to `40`):
+            The number of layers of single-stream blocks to use.
+        num_refiner_layers (`int`, defaults to `2`):
+            The number of layers of refiner blocks to use.
+        mlp_ratio (`float`, defaults to `4.0`):
+            The ratio of the hidden layer size to the input size in the feedforward network.
+        patch_size (`int`, defaults to `2`):
+            The size of the spatial patches to use in the patch embedding layer.
+        patch_size_t (`int`, defaults to `1`):
+            The size of the tmeporal patches to use in the patch embedding layer.
+        qk_norm (`str`, defaults to `rms_norm`):
+            The normalization to use for the query and key projections in the attention layers.
+        guidance_embeds (`bool`, defaults to `True`):
+            Whether to use guidance embeddings in the model.
+        text_embed_dim (`int`, defaults to `4096`):
+            Input dimension of text embeddings from the text encoder.
+        pooled_projection_dim (`int`, defaults to `768`):
+            The dimension of the pooled projection of the text embeddings.
+        rope_theta (`float`, defaults to `256.0`):
+            The value of theta to use in the RoPE layer.
+        rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
+            The dimensions of the axes to use in the RoPE layer.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 16,
+        out_channels: int = 16,
+        num_attention_heads: int = 24,
+        attention_head_dim: int = 128,
+        num_layers: int = 20,
+        num_single_layers: int = 40,
+        num_refiner_layers: int = 2,
+        mlp_ratio: float = 4.0,
+        patch_size: int = 2,
+        patch_size_t: int = 1,
+        qk_norm: str = "rms_norm",
+        guidance_embeds: bool = True,
+        text_embed_dim: int = 4096,
+        pooled_projection_dim: int = 768,
+        rope_theta: float = 256.0,
+        rope_axes_dim: Tuple[int] = (16, 56, 56),
+    ) -> None:
+        super().__init__()
+
+        inner_dim = num_attention_heads * attention_head_dim
+        out_channels = out_channels or in_channels
+
+        # 1. Latent and condition embedders
+        self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
+        self.context_embedder = HunyuanVideoTokenRefiner(
+            text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
+        )
+        self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
+
+        # 2. RoPE
+        self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
+
+        # 3. Dual stream transformer blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                HunyuanVideoTransformerBlock(
+                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        # 4. Single stream transformer blocks
+        self.single_transformer_blocks = nn.ModuleList(
+            [
+                HunyuanVideoSingleTransformerBlock(
+                    num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+                )
+                for _ in range(num_single_layers)
+            ]
+        )
+
+        # 5. Output projection
+        self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
+        self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
+
+        self.gradient_checkpointing = False
+
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_attention_mask: torch.Tensor,
+        pooled_projections: torch.Tensor,
+        guidance: torch.Tensor = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+        p, p_t = self.config.patch_size, self.config.patch_size_t
+        post_patch_num_frames = num_frames // p_t
+        post_patch_height = height // p
+        post_patch_width = width // p
+
+        # 1. RoPE
+        image_rotary_emb = self.rope(hidden_states)
+
+        # 2. Conditional embeddings
+        temb = self.time_text_embed(timestep, guidance, pooled_projections)
+        hidden_states = self.x_embedder(hidden_states)
+        encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
+
+        # 3. Attention mask preparation
+        latent_sequence_length = hidden_states.shape[1]
+        condition_sequence_length = encoder_hidden_states.shape[1]
+        sequence_length = latent_sequence_length + condition_sequence_length
+        attention_mask = torch.zeros(
+            batch_size, sequence_length, sequence_length, device=hidden_states.device, dtype=torch.bool
+        )  # [B, N, N]
+
+        effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int)  # [B,]
+        effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
+
+        for i in range(batch_size):
+            attention_mask[i, : effective_sequence_length[i], : effective_sequence_length[i]] = True
+
+        # 4. Transformer blocks
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+            def create_custom_forward(module, return_dict=None):
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
+
+                return custom_forward
+
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+
+            for block in self.transformer_blocks:
+                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    attention_mask,
+                    image_rotary_emb,
+                    **ckpt_kwargs,
+                )
+
+            for block in self.single_transformer_blocks:
+                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    attention_mask,
+                    image_rotary_emb,
+                    **ckpt_kwargs,
+                )
+
+        else:
+            for block in self.transformer_blocks:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+                )
+
+            for block in self.single_transformer_blocks:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+                )
+
+        # 5. Output projection
+        hidden_states = self.norm_out(hidden_states, temb)
+        hidden_states = self.proj_out(hidden_states)
+
+        hidden_states = hidden_states.reshape(
+            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
+        )
+        hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
+        hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (hidden_states,)
+
+        return Transformer2DModelOutput(sample=hidden_states)
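
For orientation, here is a minimal sketch of calling the new HunyuanVideoTransformer3DModel directly. Only the constructor arguments and the forward() signature are taken from the file above; the tiny configuration values, the tensor shapes, and the variable names are made up so the example is cheap to run and do not correspond to the released HunyuanVideo checkpoint.

# Minimal sketch: exercise HunyuanVideoTransformer3DModel with a tiny, made-up config.
# The argument names and forward() signature come from the diff above; the sizes below
# are illustrative only, not the real HunyuanVideo configuration.
import torch

from diffusers.models.transformers.transformer_hunyuan_video import HunyuanVideoTransformer3DModel

model = HunyuanVideoTransformer3DModel(
    in_channels=4,
    out_channels=4,
    num_attention_heads=2,
    attention_head_dim=8,
    num_layers=1,
    num_single_layers=1,
    num_refiner_layers=1,
    patch_size=2,
    patch_size_t=1,
    text_embed_dim=16,
    pooled_projection_dim=8,
    rope_axes_dim=(2, 2, 4),  # the axis dims must sum to attention_head_dim
)

batch, frames, height, width, text_len = 1, 2, 8, 8, 6
latents = torch.randn(batch, 4, frames, height, width)        # [B, C, F, H, W] video latents
prompt_embeds = torch.randn(batch, text_len, 16)               # text encoder hidden states
prompt_mask = torch.ones(batch, text_len, dtype=torch.bool)    # encoder attention mask
pooled = torch.randn(batch, 8)                                 # pooled text projection
timestep = torch.tensor([500], dtype=torch.long)
guidance = torch.tensor([6000.0])                              # guidance conditioning (guidance_embeds=True by default)

with torch.no_grad():
    sample = model(
        hidden_states=latents,
        timestep=timestep,
        encoder_hidden_states=prompt_embeds,
        encoder_attention_mask=prompt_mask,
        pooled_projections=pooled,
        guidance=guidance,
    ).sample

print(sample.shape)  # same shape as the input latents: (1, 4, 2, 8, 8)

The output keeps the latent shape because the final projection and reshape undo the (patch_size_t, patch_size, patch_size) patchification applied by HunyuanVideoPatchEmbed.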