diffusers 0.29.2__py3-none-any.whl → 0.30.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_pipeline.py +2252 -0
  10. diffusers/loaders/peft.py +213 -5
  11. diffusers/loaders/single_file.py +3 -14
  12. diffusers/loaders/single_file_model.py +31 -10
  13. diffusers/loaders/single_file_utils.py +293 -8
  14. diffusers/loaders/textual_inversion.py +1 -6
  15. diffusers/loaders/unet.py +23 -208
  16. diffusers/models/__init__.py +20 -0
  17. diffusers/models/activations.py +22 -0
  18. diffusers/models/attention.py +386 -7
  19. diffusers/models/attention_processor.py +1937 -629
  20. diffusers/models/autoencoders/__init__.py +2 -0
  21. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  22. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1271 -0
  23. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  24. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  25. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  26. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  27. diffusers/models/autoencoders/vq_model.py +4 -4
  28. diffusers/models/controlnet.py +2 -3
  29. diffusers/models/controlnet_hunyuan.py +401 -0
  30. diffusers/models/controlnet_sd3.py +11 -11
  31. diffusers/models/controlnet_sparsectrl.py +789 -0
  32. diffusers/models/controlnet_xs.py +40 -10
  33. diffusers/models/downsampling.py +68 -0
  34. diffusers/models/embeddings.py +403 -36
  35. diffusers/models/model_loading_utils.py +1 -3
  36. diffusers/models/modeling_flax_utils.py +1 -6
  37. diffusers/models/modeling_utils.py +4 -16
  38. diffusers/models/normalization.py +203 -12
  39. diffusers/models/transformers/__init__.py +6 -0
  40. diffusers/models/transformers/auraflow_transformer_2d.py +543 -0
  41. diffusers/models/transformers/cogvideox_transformer_3d.py +485 -0
  42. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  43. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  44. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  45. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  46. diffusers/models/transformers/prior_transformer.py +1 -1
  47. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  48. diffusers/models/transformers/transformer_flux.py +455 -0
  49. diffusers/models/transformers/transformer_sd3.py +18 -4
  50. diffusers/models/unets/unet_1d_blocks.py +1 -1
  51. diffusers/models/unets/unet_2d_condition.py +8 -1
  52. diffusers/models/unets/unet_3d_blocks.py +51 -920
  53. diffusers/models/unets/unet_3d_condition.py +4 -1
  54. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  55. diffusers/models/unets/unet_kandinsky3.py +1 -1
  56. diffusers/models/unets/unet_motion_model.py +1330 -84
  57. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  58. diffusers/models/unets/unet_stable_cascade.py +1 -3
  59. diffusers/models/unets/uvit_2d.py +1 -1
  60. diffusers/models/upsampling.py +64 -0
  61. diffusers/models/vq_model.py +8 -4
  62. diffusers/optimization.py +1 -1
  63. diffusers/pipelines/__init__.py +100 -3
  64. diffusers/pipelines/animatediff/__init__.py +4 -0
  65. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  66. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  70. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  71. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  72. diffusers/pipelines/aura_flow/__init__.py +48 -0
  73. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  74. diffusers/pipelines/auto_pipeline.py +97 -19
  75. diffusers/pipelines/cogvideo/__init__.py +48 -0
  76. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +746 -0
  77. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  78. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  79. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  80. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  81. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  82. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  83. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  84. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  85. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  86. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  87. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  88. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  90. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  91. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  96. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  97. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  98. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  103. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  104. diffusers/pipelines/flux/__init__.py +47 -0
  105. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  106. diffusers/pipelines/flux/pipeline_output.py +21 -0
  107. diffusers/pipelines/free_init_utils.py +2 -0
  108. diffusers/pipelines/free_noise_utils.py +236 -0
  109. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  110. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  111. diffusers/pipelines/kolors/__init__.py +54 -0
  112. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  113. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  114. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  115. diffusers/pipelines/kolors/text_encoder.py +889 -0
  116. diffusers/pipelines/kolors/tokenizer.py +334 -0
  117. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  118. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  119. diffusers/pipelines/latte/__init__.py +48 -0
  120. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  121. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  122. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  123. diffusers/pipelines/lumina/__init__.py +48 -0
  124. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  125. diffusers/pipelines/pag/__init__.py +67 -0
  126. diffusers/pipelines/pag/pag_utils.py +237 -0
  127. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  128. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  129. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  130. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  131. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  132. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  133. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  138. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  139. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  140. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  141. diffusers/pipelines/pipeline_utils.py +2 -14
  142. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  143. diffusers/pipelines/stable_audio/__init__.py +50 -0
  144. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  145. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  146. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  147. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  156. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  157. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  158. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  160. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  161. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  162. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  163. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  164. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  165. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  166. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  167. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  168. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  175. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  179. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  180. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  181. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  182. diffusers/schedulers/__init__.py +8 -0
  183. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  184. diffusers/schedulers/scheduling_ddim.py +1 -1
  185. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  186. diffusers/schedulers/scheduling_ddpm.py +1 -1
  187. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  188. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  189. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  190. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  191. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  192. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  193. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  194. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  195. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  196. diffusers/schedulers/scheduling_ipndm.py +1 -1
  197. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  198. diffusers/schedulers/scheduling_utils.py +1 -3
  199. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  200. diffusers/training_utils.py +99 -14
  201. diffusers/utils/__init__.py +2 -2
  202. diffusers/utils/dummy_pt_objects.py +210 -0
  203. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  204. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  205. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  206. diffusers/utils/dynamic_modules_utils.py +1 -11
  207. diffusers/utils/export_utils.py +50 -6
  208. diffusers/utils/hub_utils.py +45 -42
  209. diffusers/utils/import_utils.py +37 -15
  210. diffusers/utils/loading_utils.py +80 -3
  211. diffusers/utils/testing_utils.py +11 -8
  212. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/METADATA +73 -83
  213. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/RECORD +217 -164
  214. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/WHEEL +1 -1
  215. diffusers/loaders/autoencoder.py +0 -146
  216. diffusers/loaders/controlnet.py +0 -136
  217. diffusers/loaders/lora.py +0 -1728
  218. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/LICENSE +0 -0
  219. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/entry_points.txt +0 -0
  220. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/top_level.txt +0 -0
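The headline additions in 0.30 visible in the list above are the new Flux, AuraFlow, CogVideoX, Kolors, Latte, Lumina, Stable Audio and PAG pipelines, plus the LoRA loader split (loaders/lora.py is removed in favour of loaders/lora_base.py and loaders/lora_pipeline.py). As a quick orientation before the per-file hunks below, here is a hedged usage sketch for the new Flux pipeline; the hub id and call arguments are assumptions based on the 0.30 release, not facts recorded in this diff.

```python
# Hedged sketch only: the checkpoint id, dtype and call arguments are assumptions,
# not something this wheel diff verifies.
import torch
from diffusers import FluxPipeline  # new in 0.30 (diffusers/pipelines/flux/pipeline_flux.py)

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",  # assumed hub id for the distilled Flux checkpoint
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()  # keeps peak VRAM manageable on consumer GPUs

image = pipe(
    "a tiny astronaut hatching from an egg on the moon",
    guidance_scale=0.0,          # the schnell variant is guidance-distilled, so CFG stays off
    num_inference_steps=4,
    max_sequence_length=256,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("flux_schnell.png")
```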
diffusers/models/transformers/transformer_flux.py (new file)
@@ -0,0 +1,455 @@
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from typing import Any, Dict, List, Optional, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from ...configuration_utils import ConfigMixin, register_to_config
+ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+ from ...models.attention import FeedForward
+ from ...models.attention_processor import Attention, FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0
+ from ...models.modeling_utils import ModelMixin
+ from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+ from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+ from ...utils.torch_utils import maybe_allow_in_graph
+ from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
+ from ..modeling_outputs import Transformer2DModelOutput
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ # YiYi to-do: refactor rope related functions/classes
+ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
+     assert dim % 2 == 0, "The dimension must be even."
+
+     scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+     omega = 1.0 / (theta**scale)
+
+     batch_size, seq_length = pos.shape
+     out = torch.einsum("...n,d->...nd", pos, omega)
+     cos_out = torch.cos(out)
+     sin_out = torch.sin(out)
+
+     stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
+     out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
+     return out.float()
+
+
+ # YiYi to-do: refactor rope related functions/classes
+ class EmbedND(nn.Module):
+     def __init__(self, dim: int, theta: int, axes_dim: List[int]):
+         super().__init__()
+         self.dim = dim
+         self.theta = theta
+         self.axes_dim = axes_dim
+
+     def forward(self, ids: torch.Tensor) -> torch.Tensor:
+         n_axes = ids.shape[-1]
+         emb = torch.cat(
+             [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+             dim=-3,
+         )
+         return emb.unsqueeze(1)
+
+
+ @maybe_allow_in_graph
+ class FluxSingleTransformerBlock(nn.Module):
+     r"""
+     A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
+
+     Reference: https://arxiv.org/abs/2403.03206
+
+     Parameters:
+         dim (`int`): The number of channels in the input and output.
+         num_attention_heads (`int`): The number of heads to use for multi-head attention.
+         attention_head_dim (`int`): The number of channels in each head.
+         context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
+             processing of `context` conditions.
+     """
+
+     def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
+         super().__init__()
+         self.mlp_hidden_dim = int(dim * mlp_ratio)
+
+         self.norm = AdaLayerNormZeroSingle(dim)
+         self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
+         self.act_mlp = nn.GELU(approximate="tanh")
+         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
+
+         processor = FluxSingleAttnProcessor2_0()
+         self.attn = Attention(
+             query_dim=dim,
+             cross_attention_dim=None,
+             dim_head=attention_head_dim,
+             heads=num_attention_heads,
+             out_dim=dim,
+             bias=True,
+             processor=processor,
+             qk_norm="rms_norm",
+             eps=1e-6,
+             pre_only=True,
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.FloatTensor,
+         temb: torch.FloatTensor,
+         image_rotary_emb=None,
+     ):
+         residual = hidden_states
+         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+
+         attn_output = self.attn(
+             hidden_states=norm_hidden_states,
+             image_rotary_emb=image_rotary_emb,
+         )
+
+         hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+         gate = gate.unsqueeze(1)
+         hidden_states = gate * self.proj_out(hidden_states)
+         hidden_states = residual + hidden_states
+         if hidden_states.dtype == torch.float16:
+             hidden_states = hidden_states.clip(-65504, 65504)
+
+         return hidden_states
+
+
+ @maybe_allow_in_graph
+ class FluxTransformerBlock(nn.Module):
+     r"""
+     A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
+
+     Reference: https://arxiv.org/abs/2403.03206
+
+     Parameters:
+         dim (`int`): The number of channels in the input and output.
+         num_attention_heads (`int`): The number of heads to use for multi-head attention.
+         attention_head_dim (`int`): The number of channels in each head.
+         context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
+             processing of `context` conditions.
+     """
+
+     def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
+         super().__init__()
+
+         self.norm1 = AdaLayerNormZero(dim)
+
+         self.norm1_context = AdaLayerNormZero(dim)
+
+         if hasattr(F, "scaled_dot_product_attention"):
+             processor = FluxAttnProcessor2_0()
+         else:
+             raise ValueError(
+                 "The current PyTorch version does not support the `scaled_dot_product_attention` function."
+             )
+         self.attn = Attention(
+             query_dim=dim,
+             cross_attention_dim=None,
+             added_kv_proj_dim=dim,
+             dim_head=attention_head_dim,
+             heads=num_attention_heads,
+             out_dim=dim,
+             context_pre_only=False,
+             bias=True,
+             processor=processor,
+             qk_norm=qk_norm,
+             eps=eps,
+         )
+
+         self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+         self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+         self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+         self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+         # let chunk size default to None
+         self._chunk_size = None
+         self._chunk_dim = 0
+
+     def forward(
+         self,
+         hidden_states: torch.FloatTensor,
+         encoder_hidden_states: torch.FloatTensor,
+         temb: torch.FloatTensor,
+         image_rotary_emb=None,
+     ):
+         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+
+         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+             encoder_hidden_states, emb=temb
+         )
+
+         # Attention.
+         attn_output, context_attn_output = self.attn(
+             hidden_states=norm_hidden_states,
+             encoder_hidden_states=norm_encoder_hidden_states,
+             image_rotary_emb=image_rotary_emb,
+         )
+
+         # Process attention outputs for the `hidden_states`.
+         attn_output = gate_msa.unsqueeze(1) * attn_output
+         hidden_states = hidden_states + attn_output
+
+         norm_hidden_states = self.norm2(hidden_states)
+         norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+         ff_output = self.ff(norm_hidden_states)
+         ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+         hidden_states = hidden_states + ff_output
+
+         # Process attention outputs for the `encoder_hidden_states`.
+
+         context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+         encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+         norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+         norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+         context_ff_output = self.ff_context(norm_encoder_hidden_states)
+         encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+         if encoder_hidden_states.dtype == torch.float16:
+             encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+         return encoder_hidden_states, hidden_states
+
+
+ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+     """
+     The Transformer model introduced in Flux.
+
+     Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+     Parameters:
+         patch_size (`int`): Patch size to turn the input data into small patches.
+         in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
+         num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
+         num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
+         attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
+         num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
+         joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+         pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
+         guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
+     """
+
+     _supports_gradient_checkpointing = True
+
+     @register_to_config
+     def __init__(
+         self,
+         patch_size: int = 1,
+         in_channels: int = 64,
+         num_layers: int = 19,
+         num_single_layers: int = 38,
+         attention_head_dim: int = 128,
+         num_attention_heads: int = 24,
+         joint_attention_dim: int = 4096,
+         pooled_projection_dim: int = 768,
+         guidance_embeds: bool = False,
+         axes_dims_rope: List[int] = [16, 56, 56],
+     ):
+         super().__init__()
+         self.out_channels = in_channels
+         self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+
+         self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)
+         text_time_guidance_cls = (
+             CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
+         )
+         self.time_text_embed = text_time_guidance_cls(
+             embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+         )
+
+         self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
+         self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
+
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 FluxTransformerBlock(
+                     dim=self.inner_dim,
+                     num_attention_heads=self.config.num_attention_heads,
+                     attention_head_dim=self.config.attention_head_dim,
+                 )
+                 for i in range(self.config.num_layers)
+             ]
+         )
+
+         self.single_transformer_blocks = nn.ModuleList(
+             [
+                 FluxSingleTransformerBlock(
+                     dim=self.inner_dim,
+                     num_attention_heads=self.config.num_attention_heads,
+                     attention_head_dim=self.config.attention_head_dim,
+                 )
+                 for i in range(self.config.num_single_layers)
+             ]
+         )
+
+         self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+         self.gradient_checkpointing = False
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         if hasattr(module, "gradient_checkpointing"):
+             module.gradient_checkpointing = value
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor = None,
+         pooled_projections: torch.Tensor = None,
+         timestep: torch.LongTensor = None,
+         img_ids: torch.Tensor = None,
+         txt_ids: torch.Tensor = None,
+         guidance: torch.Tensor = None,
+         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+         return_dict: bool = True,
+     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+         """
+         The [`FluxTransformer2DModel`] forward method.
+
+         Args:
+             hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+                 Input `hidden_states`.
+             encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+             pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+                 from the embeddings of input conditions.
+             timestep ( `torch.LongTensor`):
+                 Used to indicate denoising step.
+             block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+                 A list of tensors that if specified are added to the residuals of transformer blocks.
+             joint_attention_kwargs (`dict`, *optional*):
+                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                 `self.processor` in
+                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+                 tuple.
+
+         Returns:
+             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+             `tuple` where the first element is the sample tensor.
+         """
+         if joint_attention_kwargs is not None:
+             joint_attention_kwargs = joint_attention_kwargs.copy()
+             lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+         else:
+             lora_scale = 1.0
+
+         if USE_PEFT_BACKEND:
+             # weight the lora layers by setting `lora_scale` for each PEFT layer
+             scale_lora_layers(self, lora_scale)
+         else:
+             if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+                 logger.warning(
+                     "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+                 )
+         hidden_states = self.x_embedder(hidden_states)
+
+         timestep = timestep.to(hidden_states.dtype) * 1000
+         if guidance is not None:
+             guidance = guidance.to(hidden_states.dtype) * 1000
+         else:
+             guidance = None
+         temb = (
+             self.time_text_embed(timestep, pooled_projections)
+             if guidance is None
+             else self.time_text_embed(timestep, guidance, pooled_projections)
+         )
+         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+         ids = torch.cat((txt_ids, img_ids), dim=1)
+         image_rotary_emb = self.pos_embed(ids)
+
+         for index_block, block in enumerate(self.transformer_blocks):
+             if self.training and self.gradient_checkpointing:
+
+                 def create_custom_forward(module, return_dict=None):
+                     def custom_forward(*inputs):
+                         if return_dict is not None:
+                             return module(*inputs, return_dict=return_dict)
+                         else:
+                             return module(*inputs)
+
+                     return custom_forward
+
+                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                 encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(block),
+                     hidden_states,
+                     encoder_hidden_states,
+                     temb,
+                     image_rotary_emb,
+                     **ckpt_kwargs,
+                 )
+
+             else:
+                 encoder_hidden_states, hidden_states = block(
+                     hidden_states=hidden_states,
+                     encoder_hidden_states=encoder_hidden_states,
+                     temb=temb,
+                     image_rotary_emb=image_rotary_emb,
+                 )
+
+         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+         for index_block, block in enumerate(self.single_transformer_blocks):
+             if self.training and self.gradient_checkpointing:
+
+                 def create_custom_forward(module, return_dict=None):
+                     def custom_forward(*inputs):
+                         if return_dict is not None:
+                             return module(*inputs, return_dict=return_dict)
+                         else:
+                             return module(*inputs)
+
+                     return custom_forward
+
+                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                 hidden_states = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(block),
+                     hidden_states,
+                     temb,
+                     image_rotary_emb,
+                     **ckpt_kwargs,
+                 )
+
+             else:
+                 hidden_states = block(
+                     hidden_states=hidden_states,
+                     temb=temb,
+                     image_rotary_emb=image_rotary_emb,
+                 )
+
+         hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+
+         hidden_states = self.norm_out(hidden_states, temb)
+         output = self.proj_out(hidden_states)
+
+         if USE_PEFT_BACKEND:
+             # remove `lora_scale` from each PEFT layer
+             unscale_lora_layers(self, lora_scale)
+
+         if not return_dict:
+             return (output,)
+
+         return Transformer2DModelOutput(sample=output)
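The new model above can be smoke-tested without downloading weights. A minimal sketch, assuming FluxTransformer2DModel is re-exported from the top-level diffusers namespace (consistent with the __init__.py changes in this release) and that axes_dims_rope must sum to attention_head_dim, as it does in the default config:

```python
# Tiny-configuration shape check for the new FluxTransformer2DModel; real Flux
# checkpoints use num_layers=19, num_single_layers=38, attention_head_dim=128.
import torch
from diffusers import FluxTransformer2DModel  # assumed top-level export in 0.30.x

model = FluxTransformer2DModel(
    patch_size=1,
    in_channels=4,
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=8,
    num_attention_heads=2,
    joint_attention_dim=32,
    pooled_projection_dim=16,
    axes_dims_rope=[2, 4, 2],  # must be even and sum to attention_head_dim
)

batch, img_seq, txt_seq = 1, 16, 8
sample = model(
    hidden_states=torch.randn(batch, img_seq, 4),            # packed latent tokens
    encoder_hidden_states=torch.randn(batch, txt_seq, 32),   # text-encoder states
    pooled_projections=torch.randn(batch, 16),
    timestep=torch.tensor([1.0]),
    img_ids=torch.zeros(batch, img_seq, 3),                  # 0.30.x expects per-token (batch, seq, 3) ids
    txt_ids=torch.zeros(batch, txt_seq, 3),
).sample
print(sample.shape)  # torch.Size([1, 16, 4]): patch_size**2 * in_channels per image token
```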
diffusers/models/transformers/transformer_sd3.py
@@ -21,7 +21,7 @@ import torch.nn as nn
  from ...configuration_utils import ConfigMixin, register_to_config
  from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
  from ...models.attention import JointTransformerBlock
- from ...models.attention_processor import Attention, AttentionProcessor
+ from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
  from ...models.modeling_utils import ModelMixin
  from ...models.normalization import AdaLayerNormContinuous
  from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
@@ -95,7 +95,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
                  JointTransformerBlock(
                      dim=self.inner_dim,
                      num_attention_heads=self.config.num_attention_heads,
-                     attention_head_dim=self.inner_dim,
+                     attention_head_dim=self.config.attention_head_dim,
                      context_pre_only=i == num_layers - 1,
                  )
                  for i in range(self.config.num_layers)
@@ -137,6 +137,18 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
          for module in self.children():
              fn_recursive_feed_forward(module, chunk_size, dim)

+     # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
+     def disable_forward_chunking(self):
+         def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+             if hasattr(module, "set_chunk_feed_forward"):
+                 module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+             for child in module.children():
+                 fn_recursive_feed_forward(child, chunk_size, dim)
+
+         for module in self.children():
+             fn_recursive_feed_forward(module, None, 0)
+
      @property
      # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
      def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -150,7 +162,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi

          def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
              if hasattr(module, "get_processor"):
-                 processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                 processors[f"{name}.processor"] = module.get_processor()

              for sub_name, child in module.named_children():
                  fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
@@ -197,7 +209,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
          for name, module in self.named_children():
              fn_recursive_attn_processor(name, module, processor)

-     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
      def fuse_qkv_projections(self):
          """
          Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
@@ -221,6 +233,8 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
              if isinstance(module, Attention):
                  module.fuse_projections(fuse=True)

+         self.set_attn_processor(FusedJointAttnProcessor2_0())
+
      # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
      def unfuse_qkv_projections(self):
          """Disables the fused QKV projection if enabled.
diffusers/models/unets/unet_1d_blocks.py
@@ -200,7 +200,7 @@ class MidResTemporalBlock1D(nn.Module):

          self.upsample = None
          if add_upsample:
-             self.upsample = Downsample1D(out_channels, use_conv=True)
+             self.upsample = Upsample1D(out_channels, use_conv=True)

          self.downsample = None
          if add_downsample:
diffusers/models/unets/unet_2d_condition.py
@@ -30,6 +30,7 @@ from ..attention_processor import (
      AttentionProcessor,
      AttnAddedKVProcessor,
      AttnProcessor,
+     FusedAttnProcessor2_0,
  )
  from ..embeddings import (
      GaussianFourierProjection,
@@ -705,7 +706,7 @@ class UNet2DConditionModel(

          def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
              if hasattr(module, "get_processor"):
-                 processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                 processors[f"{name}.processor"] = module.get_processor()

              for sub_name, child in module.named_children():
                  fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
@@ -890,6 +891,8 @@ class UNet2DConditionModel(
              if isinstance(module, Attention):
                  module.fuse_projections(fuse=True)

+         self.set_attn_processor(FusedAttnProcessor2_0())
+
      def unfuse_qkv_projections(self):
          """Disables the fused QKV projection if enabled.

@@ -1024,6 +1027,10 @@ class UNet2DConditionModel(
                  raise ValueError(
                      f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                  )
+
+             if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None:
+                 encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states)
+
              image_embeds = added_cond_kwargs.get("image_embeds")
              image_embeds = self.encoder_hid_proj(image_embeds)
              encoder_hidden_states = (encoder_hidden_states, image_embeds)
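The UNet hunks mirror the SD3 change: fuse_qkv_projections() now explicitly installs FusedAttnProcessor2_0, and attn_processors no longer requests deprecated LoRA variants from get_processor(). A small sketch for inspecting the effect; the checkpoint id is illustrative only:

```python
# Illustrative sketch: any repo containing a UNet2DConditionModel works here.
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",  # assumed/illustrative checkpoint id
    subfolder="unet",
    torch_dtype=torch.float16,
)

unet.fuse_qkv_projections()
print({type(p).__name__ for p in unet.attn_processors.values()})
# expected with 0.30.x: {'FusedAttnProcessor2_0'}

unet.unfuse_qkv_projections()  # restores the processors captured before fusing
```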