diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
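The largest functional addition in the list above is the new `diffusers/quantizers` package (entries 214–225), which adds bitsandbytes, GGUF, and torchao backends. As a rough orientation, a minimal sketch of loading a quantized model with the new config classes might look like the following; the model ID and the 4-bit settings are illustrative assumptions, not pinned by this diff:

```python
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# Sketch only: assumes the bitsandbytes backend added in 0.32.0 and a
# FLUX.1-style checkpoint laid out with a `transformer` subfolder.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative model ID
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
```

The representative diff hunks below show the other structural theme of this release: the ControlNet models moving into a dedicated `models/controlnets/` subpackage.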
diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py}
@@ -19,11 +19,11 @@ import jax
 import jax.numpy as jnp
 from flax.core.frozen_dict import FrozenDict
 
-from ..configuration_utils import ConfigMixin, flax_register_to_config
-from ..utils import BaseOutput
-from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
-from .modeling_flax_utils import FlaxModelMixin
-from .unets.unet_2d_blocks_flax import (
+from ...configuration_utils import ConfigMixin, flax_register_to_config
+from ...utils import BaseOutput
+from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
+from ..modeling_flax_utils import FlaxModelMixin
+from ..unets.unet_2d_blocks_flax import (
     FlaxCrossAttnDownBlock2D,
     FlaxDownBlock2D,
     FlaxUNetMidBlock2DCrossAttn,
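This hunk comes from the relocated Flax ControlNet (entry 40 in the list above): the module moved one package level deeper into `models/controlnets/`, so each `..` relative import gains a dot. The top-level namespace should be unaffected; as a quick sanity sketch (the shim behavior is an assumption inferred from the slimmed-down `models/controlnet.py` in entry 34, and it requires flax to be installed):

```python
# Both paths should resolve to the same class in 0.32.0: the public
# re-export and the new controlnets/ location. Old deep-import paths
# are assumed to go through thin deprecation shims.
from diffusers import FlaxControlNetModel
from diffusers.models.controlnets.controlnet_flax import FlaxControlNetModel as Relocated

assert FlaxControlNetModel is Relocated
```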
diffusers/models/controlnets/controlnet_flux.py (new file)
@@ -0,0 +1,536 @@
+# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...models.attention_processor import AttentionProcessor
+from ...models.modeling_utils import ModelMixin
+from ...utils import USE_PEFT_BACKEND, BaseOutput, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
+from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+@dataclass
+class FluxControlNetOutput(BaseOutput):
+    controlnet_block_samples: Tuple[torch.Tensor]
+    controlnet_single_block_samples: Tuple[torch.Tensor]
+
+
+class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        patch_size: int = 1,
+        in_channels: int = 64,
+        num_layers: int = 19,
+        num_single_layers: int = 38,
+        attention_head_dim: int = 128,
+        num_attention_heads: int = 24,
+        joint_attention_dim: int = 4096,
+        pooled_projection_dim: int = 768,
+        guidance_embeds: bool = False,
+        axes_dims_rope: List[int] = [16, 56, 56],
+        num_mode: int = None,
+        conditioning_embedding_channels: int = None,
+    ):
+        super().__init__()
+        self.out_channels = in_channels
+        self.inner_dim = num_attention_heads * attention_head_dim
+
+        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
+        text_time_guidance_cls = (
+            CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
+        )
+        self.time_text_embed = text_time_guidance_cls(
+            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
+        )
+
+        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
+        self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                FluxTransformerBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                )
+                for i in range(num_layers)
+            ]
+        )
+
+        self.single_transformer_blocks = nn.ModuleList(
+            [
+                FluxSingleTransformerBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                )
+                for i in range(num_single_layers)
+            ]
+        )
+
+        # controlnet_blocks
+        self.controlnet_blocks = nn.ModuleList([])
+        for _ in range(len(self.transformer_blocks)):
+            self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
+
+        self.controlnet_single_blocks = nn.ModuleList([])
+        for _ in range(len(self.single_transformer_blocks)):
+            self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
+
+        self.union = num_mode is not None
+        if self.union:
+            self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
+
+        if conditioning_embedding_channels is not None:
+            self.input_hint_block = ControlNetConditioningEmbedding(
+                conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16)
+            )
+            self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
+        else:
+            self.input_hint_block = None
+            self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
+
+        self.gradient_checkpointing = False
+
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self):
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
+    @classmethod
+    def from_transformer(
+        cls,
+        transformer,
+        num_layers: int = 4,
+        num_single_layers: int = 10,
+        attention_head_dim: int = 128,
+        num_attention_heads: int = 24,
+        load_weights_from_transformer=True,
+    ):
+        config = dict(transformer.config)
+        config["num_layers"] = num_layers
+        config["num_single_layers"] = num_single_layers
+        config["attention_head_dim"] = attention_head_dim
+        config["num_attention_heads"] = num_attention_heads
+
+        controlnet = cls.from_config(config)
+
+        if load_weights_from_transformer:
+            controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
+            controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
+            controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
+            controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
+            controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
+            controlnet.single_transformer_blocks.load_state_dict(
+                transformer.single_transformer_blocks.state_dict(), strict=False
+            )
+
+            controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
+
+        return controlnet
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        controlnet_cond: torch.Tensor,
+        controlnet_mode: torch.Tensor = None,
+        conditioning_scale: float = 1.0,
+        encoder_hidden_states: torch.Tensor = None,
+        pooled_projections: torch.Tensor = None,
+        timestep: torch.LongTensor = None,
+        img_ids: torch.Tensor = None,
+        txt_ids: torch.Tensor = None,
+        guidance: torch.Tensor = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+        """
+        The [`FluxTransformer2DModel`] forward method.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+                Input `hidden_states`.
+            controlnet_cond (`torch.Tensor`):
+                The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+            controlnet_mode (`torch.Tensor`):
+                The mode tensor of shape `(batch_size, 1)`.
+            conditioning_scale (`float`, defaults to `1.0`):
+                The scale factor for ControlNet outputs.
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+                from the embeddings of input conditions.
+            timestep ( `torch.LongTensor`):
+                Used to indicate denoising step.
+            block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+                A list of tensors that if specified are added to the residuals of transformer blocks.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+            `tuple` where the first element is the sample tensor.
+        """
+        if joint_attention_kwargs is not None:
+            joint_attention_kwargs = joint_attention_kwargs.copy()
+            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+        hidden_states = self.x_embedder(hidden_states)
+
+        if self.input_hint_block is not None:
+            controlnet_cond = self.input_hint_block(controlnet_cond)
+            batch_size, channels, height_pw, width_pw = controlnet_cond.shape
+            height = height_pw // self.config.patch_size
+            width = width_pw // self.config.patch_size
+            controlnet_cond = controlnet_cond.reshape(
+                batch_size, channels, height, self.config.patch_size, width, self.config.patch_size
+            )
+            controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
+            controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1)
+        # add
+        hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
+
+        timestep = timestep.to(hidden_states.dtype) * 1000
+        if guidance is not None:
+            guidance = guidance.to(hidden_states.dtype) * 1000
+        else:
+            guidance = None
+        temb = (
+            self.time_text_embed(timestep, pooled_projections)
+            if guidance is None
+            else self.time_text_embed(timestep, guidance, pooled_projections)
+        )
+        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+        if self.union:
+            # union mode
+            if controlnet_mode is None:
+                raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
+            # union mode emb
+            controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
+            encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
+            txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
+
+        if txt_ids.ndim == 3:
+            logger.warning(
+                "Passing `txt_ids` 3d torch.Tensor is deprecated."
+                "Please remove the batch dimension and pass it as a 2d torch Tensor"
+            )
+            txt_ids = txt_ids[0]
+        if img_ids.ndim == 3:
+            logger.warning(
+                "Passing `img_ids` 3d torch.Tensor is deprecated."
+                "Please remove the batch dimension and pass it as a 2d torch Tensor"
+            )
+            img_ids = img_ids[0]
+
+        ids = torch.cat((txt_ids, img_ids), dim=0)
+        image_rotary_emb = self.pos_embed(ids)
+
+        block_samples = ()
+        for index_block, block in enumerate(self.transformer_blocks):
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    **ckpt_kwargs,
+                )
+
+            else:
+                encoder_hidden_states, hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    image_rotary_emb=image_rotary_emb,
+                )
+            block_samples = block_samples + (hidden_states,)
+
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        single_block_samples = ()
+        for index_block, block in enumerate(self.single_transformer_blocks):
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    **ckpt_kwargs,
+                )
+
+            else:
+                hidden_states = block(
+                    hidden_states=hidden_states,
+                    temb=temb,
+                    image_rotary_emb=image_rotary_emb,
+                )
+            single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
+
+        # controlnet block
+        controlnet_block_samples = ()
+        for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
+            block_sample = controlnet_block(block_sample)
+            controlnet_block_samples = controlnet_block_samples + (block_sample,)
+
+        controlnet_single_block_samples = ()
+        for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
+            single_block_sample = controlnet_block(single_block_sample)
+            controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
+
+        # scaling
+        controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
+        controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
+
+        controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
+        controlnet_single_block_samples = (
+            None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
+        )
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (controlnet_block_samples, controlnet_single_block_samples)
+
+        return FluxControlNetOutput(
+            controlnet_block_samples=controlnet_block_samples,
+            controlnet_single_block_samples=controlnet_single_block_samples,
+        )
+
+
+class FluxMultiControlNetModel(ModelMixin):
+    r"""
+    `FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel
+
+    This module is a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to
+    be compatible with `FluxControlNetModel`.
+
+    Args:
+        controlnets (`List[FluxControlNetModel]`):
+            Provides additional conditioning to the unet during the denoising process. You must set multiple
+            `FluxControlNetModel` as a list.
+    """
+
+    def __init__(self, controlnets):
+        super().__init__()
+        self.nets = nn.ModuleList(controlnets)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        controlnet_cond: List[torch.tensor],
+        controlnet_mode: List[torch.tensor],
+        conditioning_scale: List[float],
+        encoder_hidden_states: torch.Tensor = None,
+        pooled_projections: torch.Tensor = None,
+        timestep: torch.LongTensor = None,
+        img_ids: torch.Tensor = None,
+        txt_ids: torch.Tensor = None,
+        guidance: torch.Tensor = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ) -> Union[FluxControlNetOutput, Tuple]:
+        # ControlNet-Union with multiple conditions
+        # only load one ControlNet for saving memories
+        if len(self.nets) == 1 and self.nets[0].union:
+            controlnet = self.nets[0]
+
+            for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
+                block_samples, single_block_samples = controlnet(
+                    hidden_states=hidden_states,
+                    controlnet_cond=image,
+                    controlnet_mode=mode[:, None],
+                    conditioning_scale=scale,
+                    timestep=timestep,
+                    guidance=guidance,
+                    pooled_projections=pooled_projections,
+                    encoder_hidden_states=encoder_hidden_states,
+                    txt_ids=txt_ids,
+                    img_ids=img_ids,
+                    joint_attention_kwargs=joint_attention_kwargs,
+                    return_dict=return_dict,
+                )
+
+                # merge samples
+                if i == 0:
+                    control_block_samples = block_samples
+                    control_single_block_samples = single_block_samples
+                else:
+                    control_block_samples = [
+                        control_block_sample + block_sample
+                        for control_block_sample, block_sample in zip(control_block_samples, block_samples)
+                    ]
+
+                    control_single_block_samples = [
+                        control_single_block_sample + block_sample
+                        for control_single_block_sample, block_sample in zip(
+                            control_single_block_samples, single_block_samples
+                        )
+                    ]
+
+        # Regular Multi-ControlNets
+        # load all ControlNets into memories
+        else:
+            for i, (image, mode, scale, controlnet) in enumerate(
+                zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
+            ):
+                block_samples, single_block_samples = controlnet(
+                    hidden_states=hidden_states,
+                    controlnet_cond=image,
+                    controlnet_mode=mode[:, None],
+                    conditioning_scale=scale,
+                    timestep=timestep,
+                    guidance=guidance,
+                    pooled_projections=pooled_projections,
+                    encoder_hidden_states=encoder_hidden_states,
+                    txt_ids=txt_ids,
+                    img_ids=img_ids,
+                    joint_attention_kwargs=joint_attention_kwargs,
+                    return_dict=return_dict,
+                )
+
+                # merge samples
+                if i == 0:
+                    control_block_samples = block_samples
+                    control_single_block_samples = single_block_samples
+                else:
+                    if block_samples is not None and control_block_samples is not None:
+                        control_block_samples = [
+                            control_block_sample + block_sample
+                            for control_block_sample, block_sample in zip(control_block_samples, block_samples)
+                        ]
+                    if single_block_samples is not None and control_single_block_samples is not None:
+                        control_single_block_samples = [
+                            control_single_block_sample + block_sample
+                            for control_single_block_sample, block_sample in zip(
+                                control_single_block_samples, single_block_samples
+                            )
+                        ]
+
+        return control_block_samples, control_single_block_samples
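Since this file is new in 0.32.0, a usage note may help: the `from_transformer` classmethod above initializes a trainable ControlNet by copying the embedders and the first N transformer blocks of a base Flux transformer, then zero-initializing the ControlNet projection layers. A minimal sketch, where the checkpoint ID and `subfolder` layout are illustrative assumptions rather than anything pinned by this diff:

```python
import torch
from diffusers import FluxControlNetModel, FluxTransformer2DModel

# Load a base Flux transformer (illustrative checkpoint ID), then carve a
# 4 double-block + 10 single-block ControlNet out of it; these counts are
# the from_transformer defaults shown in the diff above.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
controlnet = FluxControlNetModel.from_transformer(
    transformer, num_layers=4, num_single_layers=10
)
```

At inference time the same class plugs into the new `FluxControlNetPipeline` (entry 130), which feeds `controlnet_block_samples` and `controlnet_single_block_samples` back into the base transformer's residuals.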
diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py}
@@ -17,17 +17,17 @@ from typing import Dict, Optional, Union
 import torch
 from torch import nn
 
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import logging
-from .attention_processor import AttentionProcessor
-from .controlnet import BaseOutput, Tuple, zero_module
-from .embeddings import (
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import BaseOutput, logging
+from ..attention_processor import AttentionProcessor
+from ..embeddings import (
     HunyuanCombinedTimestepTextSizeStyleEmbedding,
     PatchEmbed,
     PixArtAlphaTextProjection,
 )
-from .modeling_utils import ModelMixin
-from .transformers.hunyuan_transformer_2d import HunyuanDiTBlock
+from ..modeling_utils import ModelMixin
+from ..transformers.hunyuan_transformer_2d import HunyuanDiTBlock
+from .controlnet import Tuple, zero_module
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
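The HunyuanDiT ControlNet (entry 42) follows the same relocation pattern, with one extra cleanup visible above: `BaseOutput` is now imported from `diffusers.utils` directly instead of being re-exported through `models/controlnet.py`. A sketch of the new deep-import paths; the class name is assumed from the diffusers 0.32 API rather than shown in this hunk:

```python
# New canonical locations after the controlnets/ move; the classes
# themselves are unchanged. HunyuanDiT2DControlNetModel is assumed to be
# the name this module exports.
from diffusers.models.controlnets.controlnet_hunyuan import HunyuanDiT2DControlNetModel
from diffusers.utils import BaseOutput  # no longer routed via models/controlnet.py
```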