diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/loaders/textual_inversion.py CHANGED
@@ -497,19 +497,19 @@ class TextualInversionLoaderMixin:
  # load embeddings of text_encoder 1 (CLIP ViT-L/14)
  pipeline.load_textual_inversion(
  state_dict["clip_l"],
- token=["<s0>", "<s1>"],
+ tokens=["<s0>", "<s1>"],
  text_encoder=pipeline.text_encoder,
  tokenizer=pipeline.tokenizer,
  )
  # load embeddings of text_encoder 2 (CLIP ViT-G/14)
  pipeline.load_textual_inversion(
  state_dict["clip_g"],
- token=["<s0>", "<s1>"],
+ tokens=["<s0>", "<s1>"],
  text_encoder=pipeline.text_encoder_2,
  tokenizer=pipeline.tokenizer_2,
  )

- # Unload explicitly from both text encoders abd tokenizers
+ # Unload explicitly from both text encoders and tokenizers
  pipeline.unload_textual_inversion(
  tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer
  )
@@ -561,6 +561,8 @@ class TextualInversionLoaderMixin:
  tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
  key_id += 1
  tokenizer._update_trie()
+ # set correct total vocab size after removing tokens
+ tokenizer._update_total_vocab_size()

  # Delete from text encoder
  text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim
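Note: the docstring fix above renames the keyword from `token` to `tokens` and corrects a typo in the unload comment. A minimal sketch of the corrected usage (the SDXL checkpoint and embedding file names are illustrative, not taken from this diff):

```python
from diffusers import AutoPipelineForText2Image
from safetensors.torch import load_file

# Illustrative model and embedding file names.
pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
state_dict = load_file("embeddings.safetensors")

# Load the embeddings into both SDXL text encoders with the `tokens` keyword.
pipeline.load_textual_inversion(
    state_dict["clip_l"], tokens=["<s0>", "<s1>"],
    text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer,
)
pipeline.load_textual_inversion(
    state_dict["clip_g"], tokens=["<s0>", "<s1>"],
    text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2,
)

# Unload explicitly from both text encoders and tokenizers.
pipeline.unload_textual_inversion(
    tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer
)
pipeline.unload_textual_inversion(
    tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2
)
```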
diffusers/loaders/transformer_flux.py ADDED
@@ -0,0 +1,181 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from contextlib import nullcontext
+
+ from ..models.embeddings import (
+ ImageProjection,
+ MultiIPAdapterImageProjection,
+ )
+ from ..models.modeling_utils import load_model_dict_into_meta
+ from ..utils import (
+ is_accelerate_available,
+ is_torch_version,
+ logging,
+ )
+
+
+ if is_accelerate_available():
+ pass
+
+ logger = logging.get_logger(__name__)
+
+
+ class FluxTransformer2DLoadersMixin:
+ """
+ Load layers into a [`FluxTransformer2DModel`].
+ """
+
+ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
+ if low_cpu_mem_usage:
+ if is_accelerate_available():
+ from accelerate import init_empty_weights
+
+ else:
+ low_cpu_mem_usage = False
+ logger.warning(
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+ " install accelerate\n```\n."
+ )
+
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
+ updated_state_dict = {}
+ image_projection = None
+ init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
+
+ if "proj.weight" in state_dict:
+ # IP-Adapter
+ num_image_text_embeds = 4
+ if state_dict["proj.weight"].shape[0] == 65536:
+ num_image_text_embeds = 16
+ clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
+ cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
+
+ with init_context():
+ image_projection = ImageProjection(
+ cross_attention_dim=cross_attention_dim,
+ image_embed_dim=clip_embeddings_dim,
+ num_image_text_embeds=num_image_text_embeds,
+ )
+
+ for key, value in state_dict.items():
+ diffusers_name = key.replace("proj", "image_embeds")
+ updated_state_dict[diffusers_name] = value
+
+ if not low_cpu_mem_usage:
+ image_projection.load_state_dict(updated_state_dict, strict=True)
+ else:
+ load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
+
+ return image_projection
+
+ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
+ from ..models.attention_processor import (
+ FluxIPAdapterJointAttnProcessor2_0,
+ )
+
+ if low_cpu_mem_usage:
+ if is_accelerate_available():
+ from accelerate import init_empty_weights
+
+ else:
+ low_cpu_mem_usage = False
+ logger.warning(
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+ " install accelerate\n```\n."
+ )
+
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
+ # set ip-adapter cross-attention processors & load state_dict
+ attn_procs = {}
+ key_id = 0
+ init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
+ for name in self.attn_processors.keys():
+ if name.startswith("single_transformer_blocks"):
+ attn_processor_class = self.attn_processors[name].__class__
+ attn_procs[name] = attn_processor_class()
+ else:
+ cross_attention_dim = self.config.joint_attention_dim
+ hidden_size = self.inner_dim
+ attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
+ num_image_text_embeds = []
+ for state_dict in state_dicts:
+ if "proj.weight" in state_dict["image_proj"]:
+ num_image_text_embed = 4
+ if state_dict["image_proj"]["proj.weight"].shape[0] == 65536:
+ num_image_text_embed = 16
+ # IP-Adapter
+ num_image_text_embeds += [num_image_text_embed]
+
+ with init_context():
+ attn_procs[name] = attn_processor_class(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ scale=1.0,
+ num_tokens=num_image_text_embeds,
+ dtype=self.dtype,
+ device=self.device,
+ )
+
+ value_dict = {}
+ for i, state_dict in enumerate(state_dicts):
+ value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
+ value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
+ value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
+ value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})
+
+ if not low_cpu_mem_usage:
+ attn_procs[name].load_state_dict(value_dict)
+ else:
+ device = self.device
+ dtype = self.dtype
+ load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
+
+ key_id += 1
+
+ return attn_procs
+
+ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
+ if not isinstance(state_dicts, list):
+ state_dicts = [state_dicts]
+
+ self.encoder_hid_proj = None
+
+ attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
+ self.set_attn_processor(attn_procs)
+
+ image_projection_layers = []
+ for state_dict in state_dicts:
+ image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
+ state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
+ )
+ image_projection_layers.append(image_projection_layer)
+
+ self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
+ self.config.encoder_hid_dim_type = "ip_image_proj"
+
+ self.to(dtype=self.dtype, device=self.device)
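Note: the new `FluxTransformer2DLoadersMixin` wires IP-Adapter image projection layers and `FluxIPAdapterJointAttnProcessor2_0` processors into `FluxTransformer2DModel`; it is reached through the pipeline-level IP-Adapter loader rather than called directly. A hedged sketch of the intended call path (repository ID, weight name, and image-encoder argument are illustrative assumptions, not taken from this diff):

```python
import torch
from diffusers import FluxPipeline
from diffusers.utils import load_image

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# Assumed repository, weight, and image-encoder names; internally the pipeline-level loader
# is expected to end up in FluxTransformer2DModel._load_ip_adapter_weights() from the mixin above.
pipe.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",
    weight_name="ip_adapter.safetensors",
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
)
pipe.set_ip_adapter_scale(1.0)

reference = load_image("https://example.com/reference.png")  # placeholder URL
image = pipe(prompt="a photo in the style of the reference", ip_adapter_image=reference).images[0]
```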
diffusers/loaders/transformer_sd3.py ADDED
@@ -0,0 +1,89 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from typing import Dict
+
+ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
+ from ..models.embeddings import IPAdapterTimeImageProjection
+ from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+
+
+ class SD3Transformer2DLoadersMixin:
+ """Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""
+
+ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
+ """Sets IP-Adapter attention processors, image projection, and loads state_dict.
+
+ Args:
+ state_dict (`Dict`):
+ State dict with keys "ip_adapter", which contains parameters for attention processors, and
+ "image_proj", which contains parameters for image projection net.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ """
+ # IP-Adapter cross attention parameters
+ hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
+ ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
+ timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1]
+
+ # Dict where key is transformer layer index, value is attention processor's state dict
+ # ip_adapter state dict keys example: "0.norm_ip.linear.weight"
+ layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
+ for key, weights in state_dict["ip_adapter"].items():
+ idx, name = key.split(".", maxsplit=1)
+ layer_state_dict[int(idx)][name] = weights
+
+ # Create IP-Adapter attention processor
+ attn_procs = {}
+ for idx, name in enumerate(self.attn_processors.keys()):
+ attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
+ hidden_size=hidden_size,
+ ip_hidden_states_dim=ip_hidden_states_dim,
+ head_dim=self.config.attention_head_dim,
+ timesteps_emb_dim=timesteps_emb_dim,
+ ).to(self.device, dtype=self.dtype)
+
+ if not low_cpu_mem_usage:
+ attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
+ else:
+ load_model_dict_into_meta(
+ attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
+ )
+
+ self.set_attn_processor(attn_procs)
+
+ # Image projetion parameters
+ embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1]
+ output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0]
+ hidden_dim = state_dict["image_proj"]["proj_in.weight"].shape[0]
+ heads = state_dict["image_proj"]["layers.0.attn.to_q.weight"].shape[0] // 64
+ num_queries = state_dict["image_proj"]["latents"].shape[1]
+ timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1]
+
+ # Image projection
+ self.image_proj = IPAdapterTimeImageProjection(
+ embed_dim=embed_dim,
+ output_dim=output_dim,
+ hidden_dim=hidden_dim,
+ heads=heads,
+ num_queries=num_queries,
+ timestep_in_dim=timestep_in_dim,
+ ).to(device=self.device, dtype=self.dtype)
+
+ if not low_cpu_mem_usage:
+ self.image_proj.load_state_dict(state_dict["image_proj"], strict=True)
+ else:
+ load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype)
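Note: `SD3Transformer2DLoadersMixin._load_ip_adapter_weights()` installs `SD3IPAdapterJointAttnProcessor2_0` processors and an `IPAdapterTimeImageProjection` head on `SD3Transformer2DModel`. A hedged sketch of how a pipeline would reach it (the checkpoint, adapter repository, and weight name are illustrative assumptions):

```python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
)

# Assumed repository and weight name; the pipeline-level loader is expected to call
# SD3Transformer2DModel._load_ip_adapter_weights() from the mixin added above.
pipe.load_ip_adapter("InstantX/SD3.5-Large-IP-Adapter", weight_name="ip-adapter.bin")
pipe.set_ip_adapter_scale(0.6)
```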
diffusers/loaders/unet.py CHANGED
@@ -36,6 +36,7 @@ from ..utils import (
  USE_PEFT_BACKEND,
  _get_model_file,
  convert_unet_state_dict_to_peft,
+ deprecate,
  get_adapter_name,
  get_peft_kwargs,
  is_accelerate_available,
@@ -115,6 +116,9 @@ class UNet2DConditionLoadersMixin:
  `default_{i}` where i is the total number of adapters being loaded.
  weight_name (`str`, *optional*, defaults to None):
  Name of the serialized state dict file.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.

  Example:

@@ -142,8 +146,14 @@ class UNet2DConditionLoadersMixin:
  adapter_name = kwargs.pop("adapter_name", None)
  _pipeline = kwargs.pop("_pipeline", None)
  network_alphas = kwargs.pop("network_alphas", None)
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
  allow_pickle = False

+ if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
  if use_safetensors is None:
  use_safetensors = True
  allow_pickle = True
@@ -200,6 +210,10 @@ class UNet2DConditionLoadersMixin:
  is_model_cpu_offload = False
  is_sequential_cpu_offload = False

+ if is_lora:
+ deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
+ deprecate("load_attn_procs", "0.40.0", deprecation_message)
+
  if is_custom_diffusion:
  attn_processors = self._process_custom_diffusion(state_dict=state_dict)
  elif is_lora:
@@ -209,6 +223,7 @@ class UNet2DConditionLoadersMixin:
  network_alphas=network_alphas,
  adapter_name=adapter_name,
  _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
  )
  else:
  raise ValueError(
@@ -268,7 +283,9 @@ class UNet2DConditionLoadersMixin:

  return attn_processors

- def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline):
+ def _process_lora(
+ self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, low_cpu_mem_usage
+ ):
  # This method does the following things:
  # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
  # format. For legacy format no filtering is applied.
@@ -335,18 +352,37 @@ class UNet2DConditionLoadersMixin:
  # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
  # otherwise loading LoRA weights will lead to an error
  is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+ peft_kwargs = {}
+ if is_peft_version(">=", "0.13.1"):
+ peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

- inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
- incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)
+ inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
+ incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)

+ warn_msg = ""
  if incompatible_keys is not None:
- # check only for unexpected keys
+ # Check only for unexpected keys.
  unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
  if unexpected_keys:
- logger.warning(
- f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
- f" {unexpected_keys}. "
- )
+ lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+ if lora_unexpected_keys:
+ warn_msg = (
+ f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+ f" {', '.join(lora_unexpected_keys)}. "
+ )
+
+ # Filter missing keys specific to the current adapter.
+ missing_keys = getattr(incompatible_keys, "missing_keys", None)
+ if missing_keys:
+ lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+ if lora_missing_keys:
+ warn_msg += (
+ f"Loading adapter weights from state_dict led to missing keys in the model:"
+ f" {', '.join(lora_missing_keys)}."
+ )
+
+ if warn_msg:
+ logger.warning(warn_msg)

  return is_model_cpu_offload, is_sequential_cpu_offload

@@ -456,6 +492,9 @@ class UNet2DConditionLoadersMixin:
  )
  state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
  else:
+ deprecation_message = "Using the `save_attn_procs()` method has been deprecated and will be removed in a future version. Please use `save_lora_adapter()`."
+ deprecate("save_attn_procs", "0.40.0", deprecation_message)
+
  if not USE_PEFT_BACKEND:
  raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")

@@ -734,6 +773,7 @@ class UNet2DConditionLoadersMixin:
  from ..models.attention_processor import (
  IPAdapterAttnProcessor,
  IPAdapterAttnProcessor2_0,
+ IPAdapterXFormersAttnProcessor,
  )

  if low_cpu_mem_usage:
@@ -773,11 +813,15 @@ class UNet2DConditionLoadersMixin:
  if cross_attention_dim is None or "motion_modules" in name:
  attn_processor_class = self.attn_processors[name].__class__
  attn_procs[name] = attn_processor_class()
-
  else:
- attn_processor_class = (
- IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
- )
+ if "XFormers" in str(self.attn_processors[name].__class__):
+ attn_processor_class = IPAdapterXFormersAttnProcessor
+ else:
+ attn_processor_class = (
+ IPAdapterAttnProcessor2_0
+ if hasattr(F, "scaled_dot_product_attention")
+ else IPAdapterAttnProcessor
+ )
  num_image_text_embeds = []
  for state_dict in state_dicts:
  if "proj.weight" in state_dict["image_proj"]:
diffusers/models/__init__.py CHANGED
@@ -27,18 +27,29 @@ _import_structure = {}
  if is_torch_available():
  _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
  _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
+ _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
  _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
+ _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
  _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
+ _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
+ _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
+ _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
  _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
  _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
  _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
  _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
  _import_structure["autoencoders.vq_model"] = ["VQModel"]
- _import_structure["controlnet"] = ["ControlNetModel"]
- _import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
- _import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
- _import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"]
- _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
+ _import_structure["controlnets.controlnet"] = ["ControlNetModel"]
+ _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
+ _import_structure["controlnets.controlnet_hunyuan"] = [
+ "HunyuanDiT2DControlNetModel",
+ "HunyuanDiT2DMultiControlNetModel",
+ ]
+ _import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
+ _import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
+ _import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
+ _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
+ _import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
  _import_structure["embeddings"] = ["ImageProjection"]
  _import_structure["modeling_utils"] = ["ModelMixin"]
  _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
@@ -50,10 +61,16 @@ if is_torch_available():
  _import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
  _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
  _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
+ _import_structure["transformers.sana_transformer"] = ["SanaTransformer2DModel"]
  _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
  _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
  _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
+ _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
+ _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
  _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
+ _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
+ _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
+ _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
  _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
  _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
  _import_structure["unets.unet_1d"] = ["UNet1DModel"]
@@ -68,7 +85,7 @@ if is_torch_available():
  _import_structure["unets.uvit_2d"] = ["UVit2DModel"]

  if is_flax_available():
- _import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
+ _import_structure["controlnets.controlnet_flax"] = ["FlaxControlNetModel"]
  _import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
  _import_structure["vae_flax"] = ["FlaxAutoencoderKL"]

@@ -78,32 +95,52 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
  from .adapter import MultiAdapter, T2IAdapter
  from .autoencoders import (
  AsymmetricAutoencoderKL,
+ AutoencoderDC,
  AutoencoderKL,
+ AutoencoderKLAllegro,
  AutoencoderKLCogVideoX,
+ AutoencoderKLHunyuanVideo,
+ AutoencoderKLLTXVideo,
+ AutoencoderKLMochi,
  AutoencoderKLTemporalDecoder,
  AutoencoderOobleck,
  AutoencoderTiny,
  ConsistencyDecoderVAE,
  VQModel,
  )
- from .controlnet import ControlNetModel
- from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
- from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
- from .controlnet_sparsectrl import SparseControlNetModel
- from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
+ from .controlnets import (
+ ControlNetModel,
+ ControlNetUnionModel,
+ ControlNetXSAdapter,
+ FluxControlNetModel,
+ FluxMultiControlNetModel,
+ HunyuanDiT2DControlNetModel,
+ HunyuanDiT2DMultiControlNetModel,
+ MultiControlNetModel,
+ SD3ControlNetModel,
+ SD3MultiControlNetModel,
+ SparseControlNetModel,
+ UNetControlNetXSModel,
+ )
  from .embeddings import ImageProjection
  from .modeling_utils import ModelMixin
  from .transformers import (
+ AllegroTransformer3DModel,
  AuraFlowTransformer2DModel,
  CogVideoXTransformer3DModel,
+ CogView3PlusTransformer2DModel,
  DiTTransformer2DModel,
  DualTransformer2DModel,
  FluxTransformer2DModel,
  HunyuanDiT2DModel,
+ HunyuanVideoTransformer3DModel,
  LatteTransformer3DModel,
+ LTXVideoTransformer3DModel,
  LuminaNextDiT2DModel,
+ MochiTransformer3DModel,
  PixArtTransformer2DModel,
  PriorTransformer,
+ SanaTransformer2DModel,
  SD3Transformer2DModel,
  StableAudioDiTModel,
  T5FilmDecoder,
@@ -125,7 +162,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
  )

  if is_flax_available():
- from .controlnet_flax import FlaxControlNetModel
+ from .controlnets import FlaxControlNetModel
  from .unets import FlaxUNet2DConditionModel
  from .vae_flax import FlaxAutoencoderKL

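Note: this reorganization moves all ControlNet variants under `diffusers.models.controlnets` and registers the new backbones and autoencoders (Allegro, CogView3+, HunyuanVideo, LTX, Mochi, Sana, DC-AE, and others). The re-exports shown above keep the public import paths stable; a short sketch of equivalent imports after the move (only the `diffusers.models`-level paths are confirmed by the hunks above):

```python
# Public paths preserved by the re-exports in diffusers/models/__init__.py.
from diffusers.models import ControlNetModel, FluxControlNetModel, SD3ControlNetModel
from diffusers.models import AutoencoderKLLTXVideo, MochiTransformer3DModel, SanaTransformer2DModel

# The concrete implementation now lives in the controlnets subpackage.
from diffusers.models.controlnets.controlnet import ControlNetModel as ControlNetModelDirect
```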
diffusers/models/activations.py CHANGED
@@ -18,7 +18,7 @@ import torch.nn.functional as F
  from torch import nn

  from ..utils import deprecate
- from ..utils.import_utils import is_torch_npu_available
+ from ..utils.import_utils import is_torch_npu_available, is_torch_version


  if is_torch_npu_available():
@@ -79,10 +79,10 @@ class GELU(nn.Module):
  self.approximate = approximate

  def gelu(self, gate: torch.Tensor) -> torch.Tensor:
- if gate.device.type != "mps":
- return F.gelu(gate, approximate=self.approximate)
- # mps: gelu is not implemented for float16
- return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+ if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+ # fp16 gelu not supported on mps before torch 2.0
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+ return F.gelu(gate, approximate=self.approximate)

  def forward(self, hidden_states):
  hidden_states = self.proj(hidden_states)
@@ -105,10 +105,10 @@ class GEGLU(nn.Module):
  self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)

  def gelu(self, gate: torch.Tensor) -> torch.Tensor:
- if gate.device.type != "mps":
- return F.gelu(gate)
- # mps: gelu is not implemented for float16
- return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+ if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+ # fp16 gelu not supported on mps before torch 2.0
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+ return F.gelu(gate)

  def forward(self, hidden_states, *args, **kwargs):
  if len(args) > 0 or kwargs.get("scale", None) is not None:
@@ -136,6 +136,7 @@ class SwiGLU(nn.Module):

  def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
  super().__init__()
+
  self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
  self.activation = nn.SiLU()

@@ -163,3 +164,15 @@ class ApproximateGELU(nn.Module):
  def forward(self, x: torch.Tensor) -> torch.Tensor:
  x = self.proj(x)
  return x * torch.sigmoid(1.702 * x)
+
+
+ class LinearActivation(nn.Module):
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"):
+ super().__init__()
+
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+ self.activation = get_activation(activation)
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj(hidden_states)
+ return self.activation(hidden_states)
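Note: the activations changes restrict the fp32 GELU fallback to MPS on torch < 2.0 and add a `LinearActivation` block (a linear projection followed by a configurable activation, SiLU by default). A minimal usage sketch of the new module (dimensions are arbitrary):

```python
import torch
from diffusers.models.activations import LinearActivation

# Project 320-dim features to 1280 dims, then apply SiLU.
act = LinearActivation(dim_in=320, dim_out=1280, activation="silu")
hidden_states = torch.randn(2, 77, 320)
out = act(hidden_states)
print(out.shape)  # torch.Size([2, 77, 1280])
```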