diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in one of the supported registries. It is provided for informational purposes only.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
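
One functional highlight from the list, before the per-file diff: the new diffusers/quantizers/ package (entries 214-225) adds bitsandbytes, GGUF, and torchao backends behind a quantization_config argument on model loading. A minimal sketch of the bitsandbytes path, assuming a Flux-style checkpoint (the repo id and 4-bit settings are illustrative, not taken from this diff):

# Hedged sketch: 4-bit loading via the new quantizers package.
# The checkpoint id and 4-bit settings are illustrative assumptions.
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed repo; any Flux-layout checkpoint works
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

The remainder of this page shows the diff for one file in detail.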
diffusers/models/transformers/transformer_flux.py CHANGED (+189 −51)
@@ -1,4 +1,4 @@
-# Copyright 2024 Black Forest Labs, The HuggingFace Team. All rights reserved.
+# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,61 +13,35 @@
 # limitations under the License.
 
 
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
 from ...models.attention import FeedForward
-from ...models.attention_processor import Attention, FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0
+from ...models.attention_processor import (
+    Attention,
+    AttentionProcessor,
+    FluxAttnProcessor2_0,
+    FluxAttnProcessor2_0_NPU,
+    FusedFluxAttnProcessor2_0,
+)
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.import_utils import is_torch_npu_available
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
+from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-# YiYi to-do: refactor rope related functions/classes
-def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
-    assert dim % 2 == 0, "The dimension must be even."
-
-    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
-    omega = 1.0 / (theta**scale)
-
-    batch_size, seq_length = pos.shape
-    out = torch.einsum("...n,d->...nd", pos, omega)
-    cos_out = torch.cos(out)
-    sin_out = torch.sin(out)
-
-    stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
-    out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
-    return out.float()
-
-
-# YiYi to-do: refactor rope related functions/classes
-class EmbedND(nn.Module):
-    def __init__(self, dim: int, theta: int, axes_dim: List[int]):
-        super().__init__()
-        self.dim = dim
-        self.theta = theta
-        self.axes_dim = axes_dim
-
-    def forward(self, ids: torch.Tensor) -> torch.Tensor:
-        n_axes = ids.shape[-1]
-        emb = torch.cat(
-            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
-            dim=-3,
-        )
-        return emb.unsqueeze(1)
-
-
 @maybe_allow_in_graph
 class FluxSingleTransformerBlock(nn.Module):
     r"""
@@ -92,7 +66,10 @@ class FluxSingleTransformerBlock(nn.Module):
         self.act_mlp = nn.GELU(approximate="tanh")
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
 
-        processor = FluxSingleAttnProcessor2_0()
+        if is_torch_npu_available():
+            processor = FluxAttnProcessor2_0_NPU()
+        else:
+            processor = FluxAttnProcessor2_0()
         self.attn = Attention(
             query_dim=dim,
             cross_attention_dim=None,
@@ -111,14 +88,16 @@ class FluxSingleTransformerBlock(nn.Module):
         hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         image_rotary_emb=None,
+        joint_attention_kwargs=None,
     ):
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
-
+        joint_attention_kwargs = joint_attention_kwargs or {}
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
         )
 
         hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
@@ -189,20 +168,27 @@ class FluxTransformerBlock(nn.Module):
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         image_rotary_emb=None,
+        joint_attention_kwargs=None,
    ):
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
 
         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
             encoder_hidden_states, emb=temb
         )
-
+        joint_attention_kwargs = joint_attention_kwargs or {}
         # Attention.
-        attn_output, context_attn_output = self.attn(
+        attention_outputs = self.attn(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
         )
 
+        if len(attention_outputs) == 2:
+            attn_output, context_attn_output = attention_outputs
+        elif len(attention_outputs) == 3:
+            attn_output, context_attn_output, ip_attn_output = attention_outputs
+
         # Process attention outputs for the `hidden_states`.
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = hidden_states + attn_output
@@ -214,6 +200,8 @@ class FluxTransformerBlock(nn.Module):
         ff_output = gate_mlp.unsqueeze(1) * ff_output
 
         hidden_states = hidden_states + ff_output
+        if len(attention_outputs) == 3:
+            hidden_states = hidden_states + ip_attn_output
 
         # Process attention outputs for the `encoder_hidden_states`.
 
@@ -231,7 +219,9 @@ class FluxTransformerBlock(nn.Module):
         return encoder_hidden_states, hidden_states
 
 
-class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+class FluxTransformer2DModel(
+    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
+):
     """
     The Transformer model introduced in Flux.
 
@@ -250,12 +240,14 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
     """
 
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
 
     @register_to_config
     def __init__(
         self,
         patch_size: int = 1,
         in_channels: int = 64,
+        out_channels: Optional[int] = None,
         num_layers: int = 19,
         num_single_layers: int = 38,
         attention_head_dim: int = 128,
@@ -263,13 +255,14 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         joint_attention_dim: int = 4096,
         pooled_projection_dim: int = 768,
         guidance_embeds: bool = False,
-        axes_dims_rope: List[int] = [16, 56, 56],
+        axes_dims_rope: Tuple[int] = (16, 56, 56),
     ):
         super().__init__()
-        self.out_channels = in_channels
+        self.out_channels = out_channels or in_channels
         self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
 
-        self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)
+        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
+
         text_time_guidance_cls = (
             CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
         )
@@ -278,7 +271,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         )
 
         self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
-        self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
+        self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)
 
         self.transformer_blocks = nn.ModuleList(
             [
@@ -307,6 +300,106 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
 
         self.gradient_checkpointing = False
 
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+        self.set_attn_processor(FusedFluxAttnProcessor2_0())
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
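
For the fuse/unfuse helpers added in the hunk above, a brief usage sketch (the transformer variable is an assumed, already-loaded FluxTransformer2DModel):

# Hedged usage sketch; `transformer` is assumed to be a loaded FluxTransformer2DModel.
transformer.fuse_qkv_projections()    # fuse query/key/value projections per layer (experimental)
# ... run inference with the fused processors ...
transformer.unfuse_qkv_projections()  # restore the processors saved before fusing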
@@ -321,7 +414,10 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         txt_ids: torch.Tensor = None,
         guidance: torch.Tensor = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        controlnet_block_samples=None,
+        controlnet_single_block_samples=None,
         return_dict: bool = True,
+        controlnet_blocks_repeat: bool = False,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
         """
         The [`FluxTransformer2DModel`] forward method.
@@ -363,6 +459,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
                 logger.warning(
                     "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                 )
+
         hidden_states = self.x_embedder(hidden_states)
 
         timestep = timestep.to(hidden_states.dtype) * 1000
@@ -370,6 +467,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
             guidance = guidance.to(hidden_states.dtype) * 1000
         else:
             guidance = None
+
         temb = (
             self.time_text_embed(timestep, pooled_projections)
             if guidance is None
@@ -377,11 +475,29 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         )
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
-        ids = torch.cat((txt_ids, img_ids), dim=1)
+        if txt_ids.ndim == 3:
+            logger.warning(
+                "Passing `txt_ids` 3d torch.Tensor is deprecated."
+                "Please remove the batch dimension and pass it as a 2d torch Tensor"
+            )
+            txt_ids = txt_ids[0]
+        if img_ids.ndim == 3:
+            logger.warning(
+                "Passing `img_ids` 3d torch.Tensor is deprecated."
+                "Please remove the batch dimension and pass it as a 2d torch Tensor"
+            )
+            img_ids = img_ids[0]
+
+        ids = torch.cat((txt_ids, img_ids), dim=0)
         image_rotary_emb = self.pos_embed(ids)
 
+        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
+            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+            ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+            joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
         for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -408,12 +524,24 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )
 
+            # controlnet residual
+            if controlnet_block_samples is not None:
+                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+                interval_control = int(np.ceil(interval_control))
+                # For Xlabs ControlNet.
+                if controlnet_blocks_repeat:
+                    hidden_states = (
+                        hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
+                    )
+                else:
+                    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
         for index_block, block in enumerate(self.single_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -438,6 +566,16 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
                     hidden_states=hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
+                )
+
+            # controlnet residual
+            if controlnet_single_block_samples is not None:
+                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
+                interval_control = int(np.ceil(interval_control))
+                hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
+                    hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+                    + controlnet_single_block_samples[index_block // interval_control]
                 )
 
         hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
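
Caller-side, the txt_ids/img_ids hunk above means position ids are now passed without a batch dimension; 3D ids still work but log a deprecation warning. A small sketch with illustrative sequence lengths:

# Hedged sketch of the 0.32.0 calling convention; shapes are illustrative.
import torch

txt_ids = torch.zeros(512, 3)               # was torch.zeros(batch_size, 512, 3) in 0.30.x
img_ids = torch.zeros(4096, 3)              # was torch.zeros(batch_size, 4096, 3)
ids = torch.cat((txt_ids, img_ids), dim=0)  # matches the new dim=0 concatenation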