diffusers 0.29.2__py3-none-any.whl → 0.30.0__py3-none-any.whl
This diff compares two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- diffusers/__init__.py +94 -3
- diffusers/commands/env.py +1 -5
- diffusers/configuration_utils.py +4 -9
- diffusers/dependency_versions_table.py +2 -2
- diffusers/image_processor.py +1 -2
- diffusers/loaders/__init__.py +17 -2
- diffusers/loaders/ip_adapter.py +10 -7
- diffusers/loaders/lora_base.py +752 -0
- diffusers/loaders/lora_pipeline.py +2222 -0
- diffusers/loaders/peft.py +213 -5
- diffusers/loaders/single_file.py +1 -12
- diffusers/loaders/single_file_model.py +31 -10
- diffusers/loaders/single_file_utils.py +262 -2
- diffusers/loaders/textual_inversion.py +1 -6
- diffusers/loaders/unet.py +23 -208
- diffusers/models/__init__.py +20 -0
- diffusers/models/activations.py +22 -0
- diffusers/models/attention.py +386 -7
- diffusers/models/attention_processor.py +1795 -629
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_kl.py +14 -3
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1035 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vq_model.py +4 -4
- diffusers/models/controlnet.py +2 -3
- diffusers/models/controlnet_hunyuan.py +401 -0
- diffusers/models/controlnet_sd3.py +11 -11
- diffusers/models/controlnet_sparsectrl.py +789 -0
- diffusers/models/controlnet_xs.py +40 -10
- diffusers/models/downsampling.py +68 -0
- diffusers/models/embeddings.py +319 -36
- diffusers/models/model_loading_utils.py +1 -3
- diffusers/models/modeling_flax_utils.py +1 -6
- diffusers/models/modeling_utils.py +4 -16
- diffusers/models/normalization.py +203 -12
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +527 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +345 -0
- diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
- diffusers/models/transformers/latte_transformer_3d.py +327 -0
- diffusers/models/transformers/lumina_nextdit2d.py +340 -0
- diffusers/models/transformers/pixart_transformer_2d.py +102 -1
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/stable_audio_transformer.py +458 -0
- diffusers/models/transformers/transformer_flux.py +455 -0
- diffusers/models/transformers/transformer_sd3.py +18 -4
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +8 -1
- diffusers/models/unets/unet_3d_blocks.py +51 -920
- diffusers/models/unets/unet_3d_condition.py +4 -1
- diffusers/models/unets/unet_i2vgen_xl.py +4 -1
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +1330 -84
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +1 -3
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +64 -0
- diffusers/models/vq_model.py +8 -4
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +100 -3
- diffusers/pipelines/animatediff/__init__.py +4 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
- diffusers/pipelines/aura_flow/__init__.py +48 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
- diffusers/pipelines/auto_pipeline.py +97 -19
- diffusers/pipelines/cogvideo/__init__.py +48 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +687 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +749 -0
- diffusers/pipelines/flux/pipeline_output.py +21 -0
- diffusers/pipelines/free_init_utils.py +2 -0
- diffusers/pipelines/free_noise_utils.py +236 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
- diffusers/pipelines/kolors/__init__.py +54 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
- diffusers/pipelines/kolors/pipeline_output.py +21 -0
- diffusers/pipelines/kolors/text_encoder.py +889 -0
- diffusers/pipelines/kolors/tokenizer.py +334 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
- diffusers/pipelines/latte/__init__.py +48 -0
- diffusers/pipelines/latte/pipeline_latte.py +881 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
- diffusers/pipelines/lumina/__init__.py +48 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
- diffusers/pipelines/pag/__init__.py +67 -0
- diffusers/pipelines/pag/pag_utils.py +237 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
- diffusers/pipelines/pia/pipeline_pia.py +30 -37
- diffusers/pipelines/pipeline_flax_utils.py +4 -9
- diffusers/pipelines/pipeline_loading_utils.py +0 -3
- diffusers/pipelines/pipeline_utils.py +2 -14
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
- diffusers/pipelines/stable_audio/__init__.py +50 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
- diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
- diffusers/schedulers/__init__.py +8 -0
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
- diffusers/schedulers/scheduling_ddim.py +1 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
- diffusers/schedulers/scheduling_ddpm.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +2 -2
- diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
- diffusers/schedulers/scheduling_ipndm.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
- diffusers/schedulers/scheduling_utils.py +1 -3
- diffusers/schedulers/scheduling_utils_flax.py +1 -3
- diffusers/training_utils.py +99 -14
- diffusers/utils/__init__.py +2 -2
- diffusers/utils/dummy_pt_objects.py +210 -0
- diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
- diffusers/utils/dynamic_modules_utils.py +1 -11
- diffusers/utils/export_utils.py +1 -4
- diffusers/utils/hub_utils.py +45 -42
- diffusers/utils/import_utils.py +19 -16
- diffusers/utils/loading_utils.py +76 -3
- diffusers/utils/testing_utils.py +11 -8
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/METADATA +73 -83
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/RECORD +217 -164
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/WHEEL +1 -1
- diffusers/loaders/autoencoder.py +0 -146
- diffusers/loaders/controlnet.py +0 -136
- diffusers/loaders/lora.py +0 -1728
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/LICENSE +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/top_level.txt +0 -0
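The file list above includes several pipelines that are new in 0.30.0 (AuraFlow, CogVideoX, Flux, Kolors, Latte, Lumina, PAG, Stable Audio). As a quick orientation, here is a minimal, hedged sketch of loading one of them through the `FluxPipeline` class added in `diffusers/pipelines/flux/pipeline_flux.py`; the checkpoint id, prompt, and generation arguments are illustrative assumptions and are not part of this diff.

```python
# Minimal sketch, not part of the diff: exercising one of the new 0.30.0 pipelines.
# Assumption: the "black-forest-labs/FLUX.1-schnell" checkpoint id and the call
# arguments are illustrative; only FluxPipeline itself comes from the new
# diffusers/pipelines/flux/pipeline_flux.py listed above.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # optional: trades speed for lower VRAM usage

image = pipe(
    "a tiny astronaut hatching from an egg on the moon",
    num_inference_steps=4,  # the schnell variant is distilled for few steps
    guidance_scale=0.0,     # schnell is guidance-distilled, so CFG is disabled
).images[0]
image.save("flux_schnell.png")
```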
diffusers/models/controlnet_hunyuan.py (new file)
@@ -0,0 +1,401 @@
+# Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Dict, Optional, Union
+
+import torch
+from torch import nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import logging
+from .attention_processor import AttentionProcessor
+from .controlnet import BaseOutput, Tuple, zero_module
+from .embeddings import (
+    HunyuanCombinedTimestepTextSizeStyleEmbedding,
+    PatchEmbed,
+    PixArtAlphaTextProjection,
+)
+from .modeling_utils import ModelMixin
+from .transformers.hunyuan_transformer_2d import HunyuanDiTBlock
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+@dataclass
+class HunyuanControlNetOutput(BaseOutput):
+    controlnet_block_samples: Tuple[torch.Tensor]
+
+
+class HunyuanDiT2DControlNetModel(ModelMixin, ConfigMixin):
+    @register_to_config
+    def __init__(
+        self,
+        conditioning_channels: int = 3,
+        num_attention_heads: int = 16,
+        attention_head_dim: int = 88,
+        in_channels: Optional[int] = None,
+        patch_size: Optional[int] = None,
+        activation_fn: str = "gelu-approximate",
+        sample_size=32,
+        hidden_size=1152,
+        transformer_num_layers: int = 40,
+        mlp_ratio: float = 4.0,
+        cross_attention_dim: int = 1024,
+        cross_attention_dim_t5: int = 2048,
+        pooled_projection_dim: int = 1024,
+        text_len: int = 77,
+        text_len_t5: int = 256,
+        use_style_cond_and_image_meta_size: bool = True,
+    ):
+        super().__init__()
+        self.num_heads = num_attention_heads
+        self.inner_dim = num_attention_heads * attention_head_dim
+
+        self.text_embedder = PixArtAlphaTextProjection(
+            in_features=cross_attention_dim_t5,
+            hidden_size=cross_attention_dim_t5 * 4,
+            out_features=cross_attention_dim,
+            act_fn="silu_fp32",
+        )
+
+        self.text_embedding_padding = nn.Parameter(
+            torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32)
+        )
+
+        self.pos_embed = PatchEmbed(
+            height=sample_size,
+            width=sample_size,
+            in_channels=in_channels,
+            embed_dim=hidden_size,
+            patch_size=patch_size,
+            pos_embed_type=None,
+        )
+
+        self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
+            hidden_size,
+            pooled_projection_dim=pooled_projection_dim,
+            seq_len=text_len_t5,
+            cross_attention_dim=cross_attention_dim_t5,
+            use_style_cond_and_image_meta_size=use_style_cond_and_image_meta_size,
+        )
+
+        # controlnet_blocks
+        self.controlnet_blocks = nn.ModuleList([])
+
+        # HunyuanDiT Blocks
+        self.blocks = nn.ModuleList(
+            [
+                HunyuanDiTBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=self.config.num_attention_heads,
+                    activation_fn=activation_fn,
+                    ff_inner_dim=int(self.inner_dim * mlp_ratio),
+                    cross_attention_dim=cross_attention_dim,
+                    qk_norm=True,  # See http://arxiv.org/abs/2302.05442 for details.
+                    skip=False,  # always False as it is the first half of the model
+                )
+                for layer in range(transformer_num_layers // 2 - 1)
+            ]
+        )
+        self.input_block = zero_module(nn.Linear(hidden_size, hidden_size))
+        for _ in range(len(self.blocks)):
+            controlnet_block = nn.Linear(hidden_size, hidden_size)
+            controlnet_block = zero_module(controlnet_block)
+            self.controlnet_blocks.append(controlnet_block)
+
+    @property
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the
+                corresponding cross attention processor. This is strongly recommended when setting trainable attention
+                processors.
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    @classmethod
+    def from_transformer(
+        cls, transformer, conditioning_channels=3, transformer_num_layers=None, load_weights_from_transformer=True
+    ):
+        config = transformer.config
+        activation_fn = config.activation_fn
+        attention_head_dim = config.attention_head_dim
+        cross_attention_dim = config.cross_attention_dim
+        cross_attention_dim_t5 = config.cross_attention_dim_t5
+        hidden_size = config.hidden_size
+        in_channels = config.in_channels
+        mlp_ratio = config.mlp_ratio
+        num_attention_heads = config.num_attention_heads
+        patch_size = config.patch_size
+        sample_size = config.sample_size
+        text_len = config.text_len
+        text_len_t5 = config.text_len_t5
+
+        conditioning_channels = conditioning_channels
+        transformer_num_layers = transformer_num_layers or config.transformer_num_layers
+
+        controlnet = cls(
+            conditioning_channels=conditioning_channels,
+            transformer_num_layers=transformer_num_layers,
+            activation_fn=activation_fn,
+            attention_head_dim=attention_head_dim,
+            cross_attention_dim=cross_attention_dim,
+            cross_attention_dim_t5=cross_attention_dim_t5,
+            hidden_size=hidden_size,
+            in_channels=in_channels,
+            mlp_ratio=mlp_ratio,
+            num_attention_heads=num_attention_heads,
+            patch_size=patch_size,
+            sample_size=sample_size,
+            text_len=text_len,
+            text_len_t5=text_len_t5,
+        )
+        if load_weights_from_transformer:
+            key = controlnet.load_state_dict(transformer.state_dict(), strict=False)
+            logger.warning(f"controlnet load from Hunyuan-DiT. missing_keys: {key[0]}")
+        return controlnet
+
+    def forward(
+        self,
+        hidden_states,
+        timestep,
+        controlnet_cond: torch.Tensor,
+        conditioning_scale: float = 1.0,
+        encoder_hidden_states=None,
+        text_embedding_mask=None,
+        encoder_hidden_states_t5=None,
+        text_embedding_mask_t5=None,
+        image_meta_size=None,
+        style=None,
+        image_rotary_emb=None,
+        return_dict=True,
+    ):
+        """
+        The [`HunyuanDiT2DControlNetModel`] forward method.
+
+        Args:
+        hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
+            The input tensor.
+        timestep ( `torch.LongTensor`, *optional*):
+            Used to indicate denoising step.
+        controlnet_cond ( `torch.Tensor` ):
+            The conditioning input to ControlNet.
+        conditioning_scale ( `float` ):
+            Indicate the conditioning scale.
+        encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+            Conditional embeddings for cross attention layer. This is the output of `BertModel`.
+        text_embedding_mask: torch.Tensor
+            An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
+            of `BertModel`.
+        encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+            Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder.
+        text_embedding_mask_t5: torch.Tensor
+            An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
+            of T5 Text Encoder.
+        image_meta_size (torch.Tensor):
+            Conditional embedding indicate the image sizes
+        style: torch.Tensor:
+            Conditional embedding indicate the style
+        image_rotary_emb (`torch.Tensor`):
+            The image rotary embeddings to apply on query and key tensors during attention calculation.
+        return_dict: bool
+            Whether to return a dictionary.
+        """
+
+        height, width = hidden_states.shape[-2:]
+
+        hidden_states = self.pos_embed(hidden_states)  # b,c,H,W -> b, N, C
+
+        # 2. pre-process
+        hidden_states = hidden_states + self.input_block(self.pos_embed(controlnet_cond))
+
+        temb = self.time_extra_emb(
+            timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype
+        )  # [B, D]
+
+        # text projection
+        batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
+        encoder_hidden_states_t5 = self.text_embedder(
+            encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
+        )
+        encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1)
+
+        encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1)
+        text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1)
+        text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()
+
+        encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)
+
+        block_res_samples = ()
+        for layer, block in enumerate(self.blocks):
+            hidden_states = block(
+                hidden_states,
+                temb=temb,
+                encoder_hidden_states=encoder_hidden_states,
+                image_rotary_emb=image_rotary_emb,
+            )  # (N, L, D)
+
+            block_res_samples = block_res_samples + (hidden_states,)
+
+        controlnet_block_res_samples = ()
+        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
+            block_res_sample = controlnet_block(block_res_sample)
+            controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
+
+        # 6. scaling
+        controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
+
+        if not return_dict:
+            return (controlnet_block_res_samples,)
+
+        return HunyuanControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
+
+
+class HunyuanDiT2DMultiControlNetModel(ModelMixin):
+    r"""
+    `HunyuanDiT2DMultiControlNetModel` wrapper class for Multi-HunyuanDiT2DControlNetModel
+
+    This module is a wrapper for multiple instances of the `HunyuanDiT2DControlNetModel`. The `forward()` API is
+    designed to be compatible with `HunyuanDiT2DControlNetModel`.
+
+    Args:
+        controlnets (`List[HunyuanDiT2DControlNetModel]`):
+            Provides additional conditioning to the unet during the denoising process. You must set multiple
+            `HunyuanDiT2DControlNetModel` as a list.
+    """
+
+    def __init__(self, controlnets):
+        super().__init__()
+        self.nets = nn.ModuleList(controlnets)
+
+    def forward(
+        self,
+        hidden_states,
+        timestep,
+        controlnet_cond: torch.Tensor,
+        conditioning_scale: float = 1.0,
+        encoder_hidden_states=None,
+        text_embedding_mask=None,
+        encoder_hidden_states_t5=None,
+        text_embedding_mask_t5=None,
+        image_meta_size=None,
+        style=None,
+        image_rotary_emb=None,
+        return_dict=True,
+    ):
+        """
+        The [`HunyuanDiT2DControlNetModel`] forward method.
+
+        Args:
+        hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
+            The input tensor.
+        timestep ( `torch.LongTensor`, *optional*):
+            Used to indicate denoising step.
+        controlnet_cond ( `torch.Tensor` ):
+            The conditioning input to ControlNet.
+        conditioning_scale ( `float` ):
+            Indicate the conditioning scale.
+        encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+            Conditional embeddings for cross attention layer. This is the output of `BertModel`.
+        text_embedding_mask: torch.Tensor
+            An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
+            of `BertModel`.
+        encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+            Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder.
+        text_embedding_mask_t5: torch.Tensor
+            An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
+            of T5 Text Encoder.
+        image_meta_size (torch.Tensor):
+            Conditional embedding indicate the image sizes
+        style: torch.Tensor:
+            Conditional embedding indicate the style
+        image_rotary_emb (`torch.Tensor`):
+            The image rotary embeddings to apply on query and key tensors during attention calculation.
+        return_dict: bool
+            Whether to return a dictionary.
+        """
+        for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
+            block_samples = controlnet(
+                hidden_states=hidden_states,
+                timestep=timestep,
+                controlnet_cond=image,
+                conditioning_scale=scale,
+                encoder_hidden_states=encoder_hidden_states,
+                text_embedding_mask=text_embedding_mask,
+                encoder_hidden_states_t5=encoder_hidden_states_t5,
+                text_embedding_mask_t5=text_embedding_mask_t5,
+                image_meta_size=image_meta_size,
+                style=style,
+                image_rotary_emb=image_rotary_emb,
+                return_dict=return_dict,
+            )
+
+            # merge samples
+            if i == 0:
+                control_block_samples = block_samples
+            else:
+                control_block_samples = [
+                    control_block_sample + block_sample
+                    for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0])
+                ]
+                control_block_samples = (control_block_samples,)
+
+        return control_block_samples
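The hunk above defines a `from_transformer` classmethod that copies the matching weights of a HunyuanDiT transformer into the ControlNet and leaves the projection layers zero-initialized. A minimal sketch of how that entry point might be used follows; the checkpoint id is an assumption, and it presumes `HunyuanDiT2DControlNetModel` and `HunyuanDiT2DModel` are both exported at the package root in 0.30.0.

```python
# Sketch only: build the new HunyuanDiT ControlNet from an existing transformer via
# the from_transformer classmethod shown in the hunk above. The checkpoint id is an
# assumption for illustration.
import torch
from diffusers import HunyuanDiT2DControlNetModel, HunyuanDiT2DModel

transformer = HunyuanDiT2DModel.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer", torch_dtype=torch.float32
)

# Copies matching weights with strict=False (missing keys are logged as a warning);
# input_block and the controlnet_blocks stay zero-initialized, as in the code above.
controlnet = HunyuanDiT2DControlNetModel.from_transformer(
    transformer, transformer_num_layers=40, load_weights_from_transformer=True
)
```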
diffusers/models/controlnet_sd3.py
@@ -22,7 +22,7 @@ import torch.nn as nn
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ..models.attention import JointTransformerBlock
-from ..models.attention_processor import Attention, AttentionProcessor
+from ..models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
 from ..models.modeling_outputs import Transformer2DModelOutput
 from ..models.modeling_utils import ModelMixin
 from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
@@ -81,7 +81,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
                 JointTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
-                    attention_head_dim=self.
+                    attention_head_dim=self.config.attention_head_dim,
                     context_pre_only=False,
                 )
                 for i in range(num_layers)
@@ -149,7 +149,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
 
         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(
+                processors[f"{name}.processor"] = module.get_processor()
 
             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
@@ -196,7 +196,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
-    # Copied from diffusers.models.
+    # Copied from diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel.fuse_qkv_projections
     def fuse_qkv_projections(self):
         """
         Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
@@ -220,6 +220,8 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)
 
+        self.set_attn_processor(FusedJointAttnProcessor2_0())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.
@@ -239,16 +241,16 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
             module.gradient_checkpointing = value
 
     @classmethod
-    def from_transformer(cls, transformer, num_layers=
+    def from_transformer(cls, transformer, num_layers=12, load_weights_from_transformer=True):
         config = transformer.config
         config["num_layers"] = num_layers or config.num_layers
         controlnet = cls(**config)
 
         if load_weights_from_transformer:
-            controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict()
-            controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict()
-            controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict()
-            controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict())
+            controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
+            controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
+            controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
+            controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
 
         controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)
 
@@ -308,8 +310,6 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
                 "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
             )
 
-        height, width = hidden_states.shape[-2:]
-
         hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
         temb = self.time_text_embed(timestep, pooled_projections)
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)