diffusers 0.29.2__py3-none-any.whl → 0.30.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_pipeline.py +2222 -0
  10. diffusers/loaders/peft.py +213 -5
  11. diffusers/loaders/single_file.py +1 -12
  12. diffusers/loaders/single_file_model.py +31 -10
  13. diffusers/loaders/single_file_utils.py +262 -2
  14. diffusers/loaders/textual_inversion.py +1 -6
  15. diffusers/loaders/unet.py +23 -208
  16. diffusers/models/__init__.py +20 -0
  17. diffusers/models/activations.py +22 -0
  18. diffusers/models/attention.py +386 -7
  19. diffusers/models/attention_processor.py +1795 -629
  20. diffusers/models/autoencoders/__init__.py +2 -0
  21. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  22. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1035 -0
  23. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  24. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  25. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  26. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  27. diffusers/models/autoencoders/vq_model.py +4 -4
  28. diffusers/models/controlnet.py +2 -3
  29. diffusers/models/controlnet_hunyuan.py +401 -0
  30. diffusers/models/controlnet_sd3.py +11 -11
  31. diffusers/models/controlnet_sparsectrl.py +789 -0
  32. diffusers/models/controlnet_xs.py +40 -10
  33. diffusers/models/downsampling.py +68 -0
  34. diffusers/models/embeddings.py +319 -36
  35. diffusers/models/model_loading_utils.py +1 -3
  36. diffusers/models/modeling_flax_utils.py +1 -6
  37. diffusers/models/modeling_utils.py +4 -16
  38. diffusers/models/normalization.py +203 -12
  39. diffusers/models/transformers/__init__.py +6 -0
  40. diffusers/models/transformers/auraflow_transformer_2d.py +527 -0
  41. diffusers/models/transformers/cogvideox_transformer_3d.py +345 -0
  42. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  43. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  44. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  45. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  46. diffusers/models/transformers/prior_transformer.py +1 -1
  47. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  48. diffusers/models/transformers/transformer_flux.py +455 -0
  49. diffusers/models/transformers/transformer_sd3.py +18 -4
  50. diffusers/models/unets/unet_1d_blocks.py +1 -1
  51. diffusers/models/unets/unet_2d_condition.py +8 -1
  52. diffusers/models/unets/unet_3d_blocks.py +51 -920
  53. diffusers/models/unets/unet_3d_condition.py +4 -1
  54. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  55. diffusers/models/unets/unet_kandinsky3.py +1 -1
  56. diffusers/models/unets/unet_motion_model.py +1330 -84
  57. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  58. diffusers/models/unets/unet_stable_cascade.py +1 -3
  59. diffusers/models/unets/uvit_2d.py +1 -1
  60. diffusers/models/upsampling.py +64 -0
  61. diffusers/models/vq_model.py +8 -4
  62. diffusers/optimization.py +1 -1
  63. diffusers/pipelines/__init__.py +100 -3
  64. diffusers/pipelines/animatediff/__init__.py +4 -0
  65. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  66. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  70. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  71. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  72. diffusers/pipelines/aura_flow/__init__.py +48 -0
  73. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  74. diffusers/pipelines/auto_pipeline.py +97 -19
  75. diffusers/pipelines/cogvideo/__init__.py +48 -0
  76. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +687 -0
  77. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  78. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  79. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  80. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  81. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  82. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  83. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  84. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  85. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  86. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  87. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  88. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  90. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  91. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  96. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  97. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  98. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  103. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  104. diffusers/pipelines/flux/__init__.py +47 -0
  105. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  106. diffusers/pipelines/flux/pipeline_output.py +21 -0
  107. diffusers/pipelines/free_init_utils.py +2 -0
  108. diffusers/pipelines/free_noise_utils.py +236 -0
  109. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  110. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  111. diffusers/pipelines/kolors/__init__.py +54 -0
  112. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  113. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  114. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  115. diffusers/pipelines/kolors/text_encoder.py +889 -0
  116. diffusers/pipelines/kolors/tokenizer.py +334 -0
  117. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  118. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  119. diffusers/pipelines/latte/__init__.py +48 -0
  120. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  121. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  122. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  123. diffusers/pipelines/lumina/__init__.py +48 -0
  124. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  125. diffusers/pipelines/pag/__init__.py +67 -0
  126. diffusers/pipelines/pag/pag_utils.py +237 -0
  127. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  128. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  129. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  130. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  131. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  132. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  133. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  138. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  139. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  140. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  141. diffusers/pipelines/pipeline_utils.py +2 -14
  142. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  143. diffusers/pipelines/stable_audio/__init__.py +50 -0
  144. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  145. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  146. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  147. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  156. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  157. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  158. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  160. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  161. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  162. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  163. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  164. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  165. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  166. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  167. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  168. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  175. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  179. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  180. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  181. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  182. diffusers/schedulers/__init__.py +8 -0
  183. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  184. diffusers/schedulers/scheduling_ddim.py +1 -1
  185. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  186. diffusers/schedulers/scheduling_ddpm.py +1 -1
  187. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  188. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  189. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  190. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  191. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  192. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  193. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  194. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  195. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  196. diffusers/schedulers/scheduling_ipndm.py +1 -1
  197. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  198. diffusers/schedulers/scheduling_utils.py +1 -3
  199. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  200. diffusers/training_utils.py +99 -14
  201. diffusers/utils/__init__.py +2 -2
  202. diffusers/utils/dummy_pt_objects.py +210 -0
  203. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  204. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  205. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  206. diffusers/utils/dynamic_modules_utils.py +1 -11
  207. diffusers/utils/export_utils.py +1 -4
  208. diffusers/utils/hub_utils.py +45 -42
  209. diffusers/utils/import_utils.py +19 -16
  210. diffusers/utils/loading_utils.py +76 -3
  211. diffusers/utils/testing_utils.py +11 -8
  212. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/METADATA +73 -83
  213. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/RECORD +217 -164
  214. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/WHEEL +1 -1
  215. diffusers/loaders/autoencoder.py +0 -146
  216. diffusers/loaders/controlnet.py +0 -136
  217. diffusers/loaders/lora.py +0 -1728
  218. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/LICENSE +0 -0
  219. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/entry_points.txt +0 -0
  220. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/top_level.txt +0 -0
diffusers/models/controlnet_hunyuan.py (new file)
@@ -0,0 +1,401 @@
+ # Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from dataclasses import dataclass
+ from typing import Dict, Optional, Union
+
+ import torch
+ from torch import nn
+
+ from ..configuration_utils import ConfigMixin, register_to_config
+ from ..utils import logging
+ from .attention_processor import AttentionProcessor
+ from .controlnet import BaseOutput, Tuple, zero_module
+ from .embeddings import (
+     HunyuanCombinedTimestepTextSizeStyleEmbedding,
+     PatchEmbed,
+     PixArtAlphaTextProjection,
+ )
+ from .modeling_utils import ModelMixin
+ from .transformers.hunyuan_transformer_2d import HunyuanDiTBlock
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ @dataclass
+ class HunyuanControlNetOutput(BaseOutput):
+     controlnet_block_samples: Tuple[torch.Tensor]
+
+
+ class HunyuanDiT2DControlNetModel(ModelMixin, ConfigMixin):
+     @register_to_config
+     def __init__(
+         self,
+         conditioning_channels: int = 3,
+         num_attention_heads: int = 16,
+         attention_head_dim: int = 88,
+         in_channels: Optional[int] = None,
+         patch_size: Optional[int] = None,
+         activation_fn: str = "gelu-approximate",
+         sample_size=32,
+         hidden_size=1152,
+         transformer_num_layers: int = 40,
+         mlp_ratio: float = 4.0,
+         cross_attention_dim: int = 1024,
+         cross_attention_dim_t5: int = 2048,
+         pooled_projection_dim: int = 1024,
+         text_len: int = 77,
+         text_len_t5: int = 256,
+         use_style_cond_and_image_meta_size: bool = True,
+     ):
+         super().__init__()
+         self.num_heads = num_attention_heads
+         self.inner_dim = num_attention_heads * attention_head_dim
+
+         self.text_embedder = PixArtAlphaTextProjection(
+             in_features=cross_attention_dim_t5,
+             hidden_size=cross_attention_dim_t5 * 4,
+             out_features=cross_attention_dim,
+             act_fn="silu_fp32",
+         )
+
+         self.text_embedding_padding = nn.Parameter(
+             torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32)
+         )
+
+         self.pos_embed = PatchEmbed(
+             height=sample_size,
+             width=sample_size,
+             in_channels=in_channels,
+             embed_dim=hidden_size,
+             patch_size=patch_size,
+             pos_embed_type=None,
+         )
+
+         self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
+             hidden_size,
+             pooled_projection_dim=pooled_projection_dim,
+             seq_len=text_len_t5,
+             cross_attention_dim=cross_attention_dim_t5,
+             use_style_cond_and_image_meta_size=use_style_cond_and_image_meta_size,
+         )
+
+         # controlnet_blocks
+         self.controlnet_blocks = nn.ModuleList([])
+
+         # HunyuanDiT Blocks
+         self.blocks = nn.ModuleList(
+             [
+                 HunyuanDiTBlock(
+                     dim=self.inner_dim,
+                     num_attention_heads=self.config.num_attention_heads,
+                     activation_fn=activation_fn,
+                     ff_inner_dim=int(self.inner_dim * mlp_ratio),
+                     cross_attention_dim=cross_attention_dim,
+                     qk_norm=True,  # See http://arxiv.org/abs/2302.05442 for details.
+                     skip=False,  # always False as it is the first half of the model
+                 )
+                 for layer in range(transformer_num_layers // 2 - 1)
+             ]
+         )
+         self.input_block = zero_module(nn.Linear(hidden_size, hidden_size))
+         for _ in range(len(self.blocks)):
+             controlnet_block = nn.Linear(hidden_size, hidden_size)
+             controlnet_block = zero_module(controlnet_block)
+             self.controlnet_blocks.append(controlnet_block)
+
+     @property
+     def attn_processors(self) -> Dict[str, AttentionProcessor]:
+         r"""
+         Returns:
+             `dict` of attention processors: A dictionary containing all attention processors used in the model with
+             indexed by its weight name.
+         """
+         # set recursively
+         processors = {}
+
+         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+             if hasattr(module, "get_processor"):
+                 processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+             for sub_name, child in module.named_children():
+                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+             return processors
+
+         for name, module in self.named_children():
+             fn_recursive_add_processors(name, module, processors)
+
+         return processors
+
+     def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+         r"""
+         Sets the attention processor to use to compute attention.
+
+         Parameters:
+             processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                 The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                 for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the
+                 corresponding cross attention processor. This is strongly recommended when setting trainable attention
+                 processors.
+         """
+         count = len(self.attn_processors.keys())
+
+         if isinstance(processor, dict) and len(processor) != count:
+             raise ValueError(
+                 f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                 f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+             )
+
+         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+             if hasattr(module, "set_processor"):
+                 if not isinstance(processor, dict):
+                     module.set_processor(processor)
+                 else:
+                     module.set_processor(processor.pop(f"{name}.processor"))
+
+             for sub_name, child in module.named_children():
+                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+         for name, module in self.named_children():
+             fn_recursive_attn_processor(name, module, processor)
+
+     @classmethod
+     def from_transformer(
+         cls, transformer, conditioning_channels=3, transformer_num_layers=None, load_weights_from_transformer=True
+     ):
+         config = transformer.config
+         activation_fn = config.activation_fn
+         attention_head_dim = config.attention_head_dim
+         cross_attention_dim = config.cross_attention_dim
+         cross_attention_dim_t5 = config.cross_attention_dim_t5
+         hidden_size = config.hidden_size
+         in_channels = config.in_channels
+         mlp_ratio = config.mlp_ratio
+         num_attention_heads = config.num_attention_heads
+         patch_size = config.patch_size
+         sample_size = config.sample_size
+         text_len = config.text_len
+         text_len_t5 = config.text_len_t5
+
+         conditioning_channels = conditioning_channels
+         transformer_num_layers = transformer_num_layers or config.transformer_num_layers
+
+         controlnet = cls(
+             conditioning_channels=conditioning_channels,
+             transformer_num_layers=transformer_num_layers,
+             activation_fn=activation_fn,
+             attention_head_dim=attention_head_dim,
+             cross_attention_dim=cross_attention_dim,
+             cross_attention_dim_t5=cross_attention_dim_t5,
+             hidden_size=hidden_size,
+             in_channels=in_channels,
+             mlp_ratio=mlp_ratio,
+             num_attention_heads=num_attention_heads,
+             patch_size=patch_size,
+             sample_size=sample_size,
+             text_len=text_len,
+             text_len_t5=text_len_t5,
+         )
+         if load_weights_from_transformer:
+             key = controlnet.load_state_dict(transformer.state_dict(), strict=False)
+             logger.warning(f"controlnet load from Hunyuan-DiT. missing_keys: {key[0]}")
+         return controlnet
+
+     def forward(
+         self,
+         hidden_states,
+         timestep,
+         controlnet_cond: torch.Tensor,
+         conditioning_scale: float = 1.0,
+         encoder_hidden_states=None,
+         text_embedding_mask=None,
+         encoder_hidden_states_t5=None,
+         text_embedding_mask_t5=None,
+         image_meta_size=None,
+         style=None,
+         image_rotary_emb=None,
+         return_dict=True,
+     ):
+         """
+         The [`HunyuanDiT2DControlNetModel`] forward method.
+
+         Args:
+         hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
+             The input tensor.
+         timestep ( `torch.LongTensor`, *optional*):
+             Used to indicate denoising step.
+         controlnet_cond ( `torch.Tensor` ):
+             The conditioning input to ControlNet.
+         conditioning_scale ( `float` ):
+             Indicate the conditioning scale.
+         encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+             Conditional embeddings for cross attention layer. This is the output of `BertModel`.
+         text_embedding_mask: torch.Tensor
+             An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
+             of `BertModel`.
+         encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+             Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder.
+         text_embedding_mask_t5: torch.Tensor
+             An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
+             of T5 Text Encoder.
+         image_meta_size (torch.Tensor):
+             Conditional embedding indicate the image sizes
+         style: torch.Tensor:
+             Conditional embedding indicate the style
+         image_rotary_emb (`torch.Tensor`):
+             The image rotary embeddings to apply on query and key tensors during attention calculation.
+         return_dict: bool
+             Whether to return a dictionary.
+         """
+
+         height, width = hidden_states.shape[-2:]
+
+         hidden_states = self.pos_embed(hidden_states)  # b,c,H,W -> b, N, C
+
+         # 2. pre-process
+         hidden_states = hidden_states + self.input_block(self.pos_embed(controlnet_cond))
+
+         temb = self.time_extra_emb(
+             timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype
+         )  # [B, D]
+
+         # text projection
+         batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
+         encoder_hidden_states_t5 = self.text_embedder(
+             encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
+         )
+         encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1)
+
+         encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1)
+         text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1)
+         text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()
+
+         encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)
+
+         block_res_samples = ()
+         for layer, block in enumerate(self.blocks):
+             hidden_states = block(
+                 hidden_states,
+                 temb=temb,
+                 encoder_hidden_states=encoder_hidden_states,
+                 image_rotary_emb=image_rotary_emb,
+             )  # (N, L, D)
+
+             block_res_samples = block_res_samples + (hidden_states,)
+
+         controlnet_block_res_samples = ()
+         for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
+             block_res_sample = controlnet_block(block_res_sample)
+             controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
+
+         # 6. scaling
+         controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
+
+         if not return_dict:
+             return (controlnet_block_res_samples,)
+
+         return HunyuanControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
+
+
+ class HunyuanDiT2DMultiControlNetModel(ModelMixin):
+     r"""
+     `HunyuanDiT2DMultiControlNetModel` wrapper class for Multi-HunyuanDiT2DControlNetModel
+
+     This module is a wrapper for multiple instances of the `HunyuanDiT2DControlNetModel`. The `forward()` API is
+     designed to be compatible with `HunyuanDiT2DControlNetModel`.
+
+     Args:
+         controlnets (`List[HunyuanDiT2DControlNetModel]`):
+             Provides additional conditioning to the unet during the denoising process. You must set multiple
+             `HunyuanDiT2DControlNetModel` as a list.
+     """
+
+     def __init__(self, controlnets):
+         super().__init__()
+         self.nets = nn.ModuleList(controlnets)
+
+     def forward(
+         self,
+         hidden_states,
+         timestep,
+         controlnet_cond: torch.Tensor,
+         conditioning_scale: float = 1.0,
+         encoder_hidden_states=None,
+         text_embedding_mask=None,
+         encoder_hidden_states_t5=None,
+         text_embedding_mask_t5=None,
+         image_meta_size=None,
+         style=None,
+         image_rotary_emb=None,
+         return_dict=True,
+     ):
+         """
+         The [`HunyuanDiT2DControlNetModel`] forward method.
+
+         Args:
+         hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
+             The input tensor.
+         timestep ( `torch.LongTensor`, *optional*):
+             Used to indicate denoising step.
+         controlnet_cond ( `torch.Tensor` ):
+             The conditioning input to ControlNet.
+         conditioning_scale ( `float` ):
+             Indicate the conditioning scale.
+         encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+             Conditional embeddings for cross attention layer. This is the output of `BertModel`.
+         text_embedding_mask: torch.Tensor
+             An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
+             of `BertModel`.
+         encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+             Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder.
+         text_embedding_mask_t5: torch.Tensor
+             An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
+             of T5 Text Encoder.
+         image_meta_size (torch.Tensor):
+             Conditional embedding indicate the image sizes
+         style: torch.Tensor:
+             Conditional embedding indicate the style
+         image_rotary_emb (`torch.Tensor`):
+             The image rotary embeddings to apply on query and key tensors during attention calculation.
+         return_dict: bool
+             Whether to return a dictionary.
+         """
+         for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
+             block_samples = controlnet(
+                 hidden_states=hidden_states,
+                 timestep=timestep,
+                 controlnet_cond=image,
+                 conditioning_scale=scale,
+                 encoder_hidden_states=encoder_hidden_states,
+                 text_embedding_mask=text_embedding_mask,
+                 encoder_hidden_states_t5=encoder_hidden_states_t5,
+                 text_embedding_mask_t5=text_embedding_mask_t5,
+                 image_meta_size=image_meta_size,
+                 style=style,
+                 image_rotary_emb=image_rotary_emb,
+                 return_dict=return_dict,
+             )
+
+             # merge samples
+             if i == 0:
+                 control_block_samples = block_samples
+             else:
+                 control_block_samples = [
+                     control_block_sample + block_sample
+                     for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0])
+                 ]
+                 control_block_samples = (control_block_samples,)
+
+         return control_block_samples
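The `from_transformer` classmethod added above initializes a HunyuanDiT ControlNet from the first half of an existing HunyuanDiT transformer. A minimal usage sketch follows; the repository id and dtype are illustrative assumptions, not part of this diff:

import torch
from diffusers import HunyuanDiT2DModel
from diffusers.models.controlnet_hunyuan import HunyuanDiT2DControlNetModel

# Base HunyuanDiT transformer; the repo id below is a placeholder assumption.
transformer = HunyuanDiT2DModel.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer", torch_dtype=torch.float16
)

# Copies the first transformer_num_layers // 2 - 1 blocks from the transformer
# (non-strict load, hence the missing_keys warning emitted by from_transformer).
controlnet = HunyuanDiT2DControlNetModel.from_transformer(
    transformer, conditioning_channels=3, load_weights_from_transformer=True
)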
diffusers/models/controlnet_sd3.py
@@ -22,7 +22,7 @@ import torch.nn as nn
  from ..configuration_utils import ConfigMixin, register_to_config
  from ..loaders import FromOriginalModelMixin, PeftAdapterMixin
  from ..models.attention import JointTransformerBlock
- from ..models.attention_processor import Attention, AttentionProcessor
+ from ..models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
  from ..models.modeling_outputs import Transformer2DModelOutput
  from ..models.modeling_utils import ModelMixin
  from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
@@ -81,7 +81,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
                  JointTransformerBlock(
                      dim=self.inner_dim,
                      num_attention_heads=num_attention_heads,
-                     attention_head_dim=self.inner_dim,
+                     attention_head_dim=self.config.attention_head_dim,
                      context_pre_only=False,
                  )
                  for i in range(num_layers)
@@ -149,7 +149,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal

          def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
              if hasattr(module, "get_processor"):
-                 processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                 processors[f"{name}.processor"] = module.get_processor()

              for sub_name, child in module.named_children():
                  fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
@@ -196,7 +196,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
          for name, module in self.named_children():
              fn_recursive_attn_processor(name, module, processor)

-     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+     # Copied from diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel.fuse_qkv_projections
      def fuse_qkv_projections(self):
          """
          Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
@@ -220,6 +220,8 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
              if isinstance(module, Attention):
                  module.fuse_projections(fuse=True)

+         self.set_attn_processor(FusedJointAttnProcessor2_0())
+
      # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
      def unfuse_qkv_projections(self):
          """Disables the fused QKV projection if enabled.
@@ -239,16 +241,16 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
              module.gradient_checkpointing = value

      @classmethod
-     def from_transformer(cls, transformer, num_layers=None, load_weights_from_transformer=True):
+     def from_transformer(cls, transformer, num_layers=12, load_weights_from_transformer=True):
          config = transformer.config
          config["num_layers"] = num_layers or config.num_layers
          controlnet = cls(**config)

          if load_weights_from_transformer:
-             controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict(), strict=False)
-             controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict(), strict=False)
-             controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict(), strict=False)
-             controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict())
+             controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
+             controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
+             controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
+             controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)

              controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)

@@ -308,8 +310,6 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
                      "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                  )

-         height, width = hidden_states.shape[-2:]
-
          hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
          temb = self.time_text_embed(timestep, pooled_projections)
          encoder_hidden_states = self.context_embedder(encoder_hidden_states)
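For the SD3 ControlNet, `from_transformer` now defaults to `num_layers=12`, loads the embedding modules strictly, and loads the transformer blocks with `strict=False`. A hedged sketch of the call path follows; the repository id is an illustrative assumption, not part of this diff:

import torch
from diffusers import SD3Transformer2DModel
from diffusers.models.controlnet_sd3 import SD3ControlNetModel

# Base SD3 transformer; the repo id below is a placeholder assumption.
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="transformer", torch_dtype=torch.float16
)

# Builds a 12-layer ControlNet; pos_embed_input is zero-initialized as in the hunk above.
controlnet = SD3ControlNetModel.from_transformer(transformer, num_layers=12)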