diffusers 0.27.2__py3-none-any.whl → 0.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. diffusers/__init__.py +26 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +33 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +8 -0
  21. diffusers/models/activations.py +23 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +475 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +363 -32
  35. diffusers/models/model_loading_utils.py +177 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_outputs.py +14 -0
  39. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  40. diffusers/models/modeling_utils.py +175 -99
  41. diffusers/models/normalization.py +2 -1
  42. diffusers/models/resnet.py +18 -23
  43. diffusers/models/transformer_temporal.py +3 -3
  44. diffusers/models/transformers/__init__.py +3 -0
  45. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  46. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  47. diffusers/models/transformers/hunyuan_transformer_2d.py +427 -0
  48. diffusers/models/transformers/pixart_transformer_2d.py +336 -0
  49. diffusers/models/transformers/prior_transformer.py +7 -7
  50. diffusers/models/transformers/t5_film_transformer.py +17 -19
  51. diffusers/models/transformers/transformer_2d.py +292 -184
  52. diffusers/models/transformers/transformer_temporal.py +10 -10
  53. diffusers/models/unets/unet_1d.py +5 -5
  54. diffusers/models/unets/unet_1d_blocks.py +29 -29
  55. diffusers/models/unets/unet_2d.py +6 -6
  56. diffusers/models/unets/unet_2d_blocks.py +137 -128
  57. diffusers/models/unets/unet_2d_condition.py +19 -15
  58. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  59. diffusers/models/unets/unet_3d_blocks.py +79 -77
  60. diffusers/models/unets/unet_3d_condition.py +13 -9
  61. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  62. diffusers/models/unets/unet_kandinsky3.py +1 -1
  63. diffusers/models/unets/unet_motion_model.py +114 -14
  64. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  65. diffusers/models/unets/unet_stable_cascade.py +16 -13
  66. diffusers/models/upsampling.py +17 -20
  67. diffusers/models/vq_model.py +16 -15
  68. diffusers/pipelines/__init__.py +27 -3
  69. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  70. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  71. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  72. diffusers/pipelines/animatediff/__init__.py +2 -0
  73. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  74. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  75. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  76. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  77. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  78. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  79. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  80. diffusers/pipelines/auto_pipeline.py +21 -17
  81. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  82. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  83. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  84. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  85. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  86. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  87. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  88. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  89. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  90. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  91. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  92. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  93. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  94. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  95. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  96. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  97. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  98. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  99. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  100. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  101. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  102. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  103. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  104. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  105. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  106. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  107. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  108. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  109. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
  110. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  111. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  112. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  113. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  114. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  115. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  116. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  117. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  118. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  119. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  120. diffusers/pipelines/free_init_utils.py +39 -38
  121. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  122. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +881 -0
  123. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  124. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  125. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  126. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  127. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  128. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  129. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  130. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  131. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  132. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  133. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  134. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  135. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  136. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  137. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  138. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  139. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  140. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  141. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  142. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  143. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  144. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  145. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  146. diffusers/pipelines/marigold/__init__.py +50 -0
  147. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  148. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  149. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  150. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  151. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  152. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  153. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  154. diffusers/pipelines/pipeline_loading_utils.py +269 -23
  155. diffusers/pipelines/pipeline_utils.py +266 -37
  156. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +69 -79
  158. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  159. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  160. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  161. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  162. diffusers/pipelines/shap_e/renderer.py +1 -1
  163. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
  164. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  165. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  166. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  167. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  168. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  169. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  172. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  173. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  174. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  175. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  176. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  177. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  178. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  179. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  180. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  181. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  182. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
  183. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  184. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  185. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  186. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  187. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  188. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  189. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  190. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  191. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  192. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  193. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  194. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  195. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  196. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  197. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  198. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  199. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  200. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  201. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  202. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  203. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  204. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  205. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  206. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  207. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  208. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  209. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  210. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  211. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  212. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  213. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  214. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  215. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  216. diffusers/schedulers/__init__.py +2 -2
  217. diffusers/schedulers/deprecated/__init__.py +1 -1
  218. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  219. diffusers/schedulers/scheduling_amused.py +5 -5
  220. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  221. diffusers/schedulers/scheduling_consistency_models.py +20 -26
  222. diffusers/schedulers/scheduling_ddim.py +22 -24
  223. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  224. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  225. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  226. diffusers/schedulers/scheduling_ddpm.py +20 -22
  227. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  228. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  229. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  230. diffusers/schedulers/scheduling_deis_multistep.py +42 -42
  231. diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
  232. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
  236. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
  237. diffusers/schedulers/scheduling_edm_euler.py +50 -31
  238. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
  239. diffusers/schedulers/scheduling_euler_discrete.py +160 -68
  240. diffusers/schedulers/scheduling_heun_discrete.py +57 -39
  241. diffusers/schedulers/scheduling_ipndm.py +8 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
  244. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  245. diffusers/schedulers/scheduling_lcm.py +21 -23
  246. diffusers/schedulers/scheduling_lms_discrete.py +24 -26
  247. diffusers/schedulers/scheduling_pndm.py +20 -20
  248. diffusers/schedulers/scheduling_repaint.py +20 -20
  249. diffusers/schedulers/scheduling_sasolver.py +55 -54
  250. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  251. diffusers/schedulers/scheduling_tcd.py +39 -30
  252. diffusers/schedulers/scheduling_unclip.py +15 -15
  253. diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
  254. diffusers/schedulers/scheduling_utils.py +14 -5
  255. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  256. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  257. diffusers/training_utils.py +56 -1
  258. diffusers/utils/__init__.py +7 -0
  259. diffusers/utils/doc_utils.py +1 -0
  260. diffusers/utils/dummy_pt_objects.py +75 -0
  261. diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
  262. diffusers/utils/dynamic_modules_utils.py +24 -11
  263. diffusers/utils/hub_utils.py +3 -2
  264. diffusers/utils/import_utils.py +91 -0
  265. diffusers/utils/loading_utils.py +2 -2
  266. diffusers/utils/logging.py +1 -1
  267. diffusers/utils/peft_utils.py +32 -5
  268. diffusers/utils/state_dict_utils.py +11 -2
  269. diffusers/utils/testing_utils.py +71 -6
  270. diffusers/utils/torch_utils.py +1 -0
  271. diffusers/video_processor.py +113 -0
  272. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/METADATA +7 -7
  273. diffusers-0.28.1.dist-info/RECORD +419 -0
  274. diffusers-0.27.2.dist-info/RECORD +0 -399
  275. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/LICENSE +0 -0
  276. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/WHEEL +0 -0
  277. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/entry_points.txt +0 -0
  278. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/top_level.txt +0 -0
diffusers/models/model_loading_utils.py
@@ -0,0 +1,177 @@
+ # coding=utf-8
+ # Copyright 2024 The HuggingFace Inc. team.
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import importlib
+ import inspect
+ import os
+ from collections import OrderedDict
+ from typing import List, Optional, Union
+
+ import safetensors
+ import torch
+
+ from ..utils import (
+     SAFETENSORS_FILE_EXTENSION,
+     is_accelerate_available,
+     is_torch_version,
+     logging,
+ )
+
+
+ logger = logging.get_logger(__name__)
+
+ _CLASS_REMAPPING_DICT = {
+     "Transformer2DModel": {
+         "ada_norm_zero": "DiTTransformer2DModel",
+         "ada_norm_single": "PixArtTransformer2DModel",
+     }
+ }
+
+
+ if is_accelerate_available():
+     from accelerate import infer_auto_device_map
+     from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device
+
+
+ # Adapted from `transformers` (see modeling_utils.py)
+ def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype):
+     if isinstance(device_map, str):
+         no_split_modules = model._get_no_split_modules(device_map)
+         device_map_kwargs = {"no_split_module_classes": no_split_modules}
+
+         if device_map != "sequential":
+             max_memory = get_balanced_memory(
+                 model,
+                 dtype=torch_dtype,
+                 low_zero=(device_map == "balanced_low_0"),
+                 max_memory=max_memory,
+                 **device_map_kwargs,
+             )
+         else:
+             max_memory = get_max_memory(max_memory)
+
+         device_map_kwargs["max_memory"] = max_memory
+         device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)
+
+     return device_map
+
+
+ def _fetch_remapped_cls_from_config(config, old_class):
+     previous_class_name = old_class.__name__
+     remapped_class_name = _CLASS_REMAPPING_DICT.get(previous_class_name).get(config["norm_type"], None)
+
+     # Details:
+     # https://github.com/huggingface/diffusers/pull/7647#discussion_r1621344818
+     if remapped_class_name:
+         # load diffusers library to import compatible and original scheduler
+         diffusers_library = importlib.import_module(__name__.split(".")[0])
+         remapped_class = getattr(diffusers_library, remapped_class_name)
+         logger.info(
+             f"Changing class object to be of `{remapped_class_name}` type from `{previous_class_name}` type."
+             f"This is because `{previous_class_name}` is scheduled to be deprecated in a future version. Note that this"
+             " DOESN'T affect the final results."
+         )
+         return remapped_class
+     else:
+         return old_class
+
+
+ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
+     """
+     Reads a checkpoint file, returning properly formatted errors if they arise.
+     """
+     try:
+         file_extension = os.path.basename(checkpoint_file).split(".")[-1]
+         if file_extension == SAFETENSORS_FILE_EXTENSION:
+             return safetensors.torch.load_file(checkpoint_file, device="cpu")
+         else:
+             weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
+             return torch.load(
+                 checkpoint_file,
+                 map_location="cpu",
+                 **weights_only_kwarg,
+             )
+     except Exception as e:
+         try:
+             with open(checkpoint_file) as f:
+                 if f.read().startswith("version"):
+                     raise OSError(
+                         "You seem to have cloned a repository without having git-lfs installed. Please install "
+                         "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+                         "you cloned."
+                     )
+                 else:
+                     raise ValueError(
+                         f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
+                         "model. Make sure you have saved the model properly."
+                     ) from e
+         except (UnicodeDecodeError, ValueError):
+             raise OSError(
+                 f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. "
+             )
+
+
+ def load_model_dict_into_meta(
+     model,
+     state_dict: OrderedDict,
+     device: Optional[Union[str, torch.device]] = None,
+     dtype: Optional[Union[str, torch.dtype]] = None,
+     model_name_or_path: Optional[str] = None,
+ ) -> List[str]:
+     device = device or torch.device("cpu")
+     dtype = dtype or torch.float32
+
+     accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
+
+     unexpected_keys = []
+     empty_state_dict = model.state_dict()
+     for param_name, param in state_dict.items():
+         if param_name not in empty_state_dict:
+             unexpected_keys.append(param_name)
+             continue
+
+         if empty_state_dict[param_name].shape != param.shape:
+             model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
+             raise ValueError(
+                 f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
+             )
+
+         if accepts_dtype:
+             set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
+         else:
+             set_module_tensor_to_device(model, param_name, device, value=param)
+     return unexpected_keys
+
+
+ def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
+     # Convert old format to new format if needed from a PyTorch state_dict
+     # copy state_dict so _load_from_state_dict can modify it
+     state_dict = state_dict.copy()
+     error_msgs = []
+
+     # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+     # so we need to apply the function recursively.
+     def load(module: torch.nn.Module, prefix: str = ""):
+         args = (state_dict, prefix, {}, True, [], [], error_msgs)
+         module._load_from_state_dict(*args)
+
+         for name, child in module._modules.items():
+             if child is not None:
+                 load(child, prefix + name + ".")
+
+     load(model_to_load)
+
+     return error_msgs
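The new `model_loading_utils.py` module centralizes checkpoint loading. In particular, `load_state_dict` dispatches on the file extension, using `safetensors.torch.load_file` for `.safetensors` files and falling back to `torch.load` (with `weights_only=True` on PyTorch >= 1.13) otherwise. A minimal usage sketch, assuming a local checkpoint path (the path below is a placeholder, not part of this diff):

    from diffusers.models.model_loading_utils import load_state_dict

    # Dispatches on the extension: .safetensors -> safetensors.torch.load_file,
    # anything else -> torch.load on CPU.
    state_dict = load_state_dict("./unet/diffusion_pytorch_model.safetensors")
    print(f"loaded {len(state_dict)} tensors")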
diffusers/models/modeling_flax_pytorch_utils.py
@@ -12,7 +12,8 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """ PyTorch - Flax general utilities."""
+ """PyTorch - Flax general utilities."""
+
  import re

  import jax.numpy as jnp
diffusers/models/modeling_flax_utils.py
@@ -245,9 +245,9 @@ class FlaxModelMixin(PushToHubMixin):
              force_download (`bool`, *optional*, defaults to `False`):
                  Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                  cached versions if they exist.
-             resume_download (`bool`, *optional*, defaults to `False`):
-                 Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
-                 incompletely downloaded files are deleted.
+             resume_download:
+                 Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
+                 of Diffusers.
              proxies (`Dict[str, str]`, *optional*):
                  A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                  'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -296,7 +296,7 @@ class FlaxModelMixin(PushToHubMixin):
          cache_dir = kwargs.pop("cache_dir", None)
          force_download = kwargs.pop("force_download", False)
          from_pt = kwargs.pop("from_pt", False)
-         resume_download = kwargs.pop("resume_download", False)
+         resume_download = kwargs.pop("resume_download", None)
          proxies = kwargs.pop("proxies", None)
          local_files_only = kwargs.pop("local_files_only", False)
          token = kwargs.pop("token", None)
diffusers/models/modeling_outputs.py
@@ -15,3 +15,17 @@ class AutoencoderKLOutput(BaseOutput):
      """

      latent_dist: "DiagonalGaussianDistribution"  # noqa: F821
+
+
+ @dataclass
+ class Transformer2DModelOutput(BaseOutput):
+     """
+     The output of [`Transformer2DModel`].
+
+     Args:
+         sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
+             The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
+             distributions for the unnoised latent pixels.
+     """
+
+     sample: "torch.Tensor"  # noqa: F821
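With this addition, `Transformer2DModelOutput` is also exposed from `diffusers.models.modeling_outputs`. A small sketch of constructing it directly, assuming the 0.28.1 module layout shown above:

    import torch
    from diffusers.models.modeling_outputs import Transformer2DModelOutput

    # BaseOutput subclasses behave both like dataclasses and like dicts.
    out = Transformer2DModelOutput(sample=torch.zeros(1, 4, 64, 64))
    print(out.sample.shape, out["sample"].shape)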
diffusers/models/modeling_pytorch_flax_utils.py
@@ -12,7 +12,7 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- """ PyTorch - Flax general utilities."""
+ """PyTorch - Flax general utilities."""

  from pickle import UnpicklingError

diffusers/models/modeling_utils.py
@@ -20,6 +20,7 @@ import os
  import re
  from collections import OrderedDict
  from functools import partial
+ from pathlib import Path
  from typing import Any, Callable, List, Optional, Tuple, Union

  import safetensors
@@ -32,7 +33,6 @@ from .. import __version__
  from ..utils import (
      CONFIG_NAME,
      FLAX_WEIGHTS_NAME,
-     SAFETENSORS_FILE_EXTENSION,
      SAFETENSORS_WEIGHTS_NAME,
      WEIGHTS_NAME,
      _add_variant,
@@ -42,7 +42,17 @@ from ..utils import (
      is_torch_version,
      logging,
  )
- from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
+ from ..utils.hub_utils import (
+     PushToHubMixin,
+     load_or_create_model_card,
+     populate_model_card,
+ )
+ from .model_loading_utils import (
+     _determine_device_map,
+     _load_state_dict_into_model,
+     load_model_dict_into_meta,
+     load_state_dict,
+ )


  logger = logging.get_logger(__name__)
@@ -56,8 +66,6 @@ else:

  if is_accelerate_available():
      import accelerate
-     from accelerate.utils import set_module_tensor_to_device
-     from accelerate.utils.versions import is_torch_version


  def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
@@ -98,89 +106,6 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
          return first_tuple[1].dtype


- def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
-     """
-     Reads a checkpoint file, returning properly formatted errors if they arise.
-     """
-     try:
-         file_extension = os.path.basename(checkpoint_file).split(".")[-1]
-         if file_extension == SAFETENSORS_FILE_EXTENSION:
-             return safetensors.torch.load_file(checkpoint_file, device="cpu")
-         else:
-             return torch.load(checkpoint_file, map_location="cpu")
-     except Exception as e:
-         try:
-             with open(checkpoint_file) as f:
-                 if f.read().startswith("version"):
-                     raise OSError(
-                         "You seem to have cloned a repository without having git-lfs installed. Please install "
-                         "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
-                         "you cloned."
-                     )
-                 else:
-                     raise ValueError(
-                         f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
-                         "model. Make sure you have saved the model properly."
-                     ) from e
-         except (UnicodeDecodeError, ValueError):
-             raise OSError(
-                 f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. "
-             )
-
-
- def load_model_dict_into_meta(
-     model,
-     state_dict: OrderedDict,
-     device: Optional[Union[str, torch.device]] = None,
-     dtype: Optional[Union[str, torch.dtype]] = None,
-     model_name_or_path: Optional[str] = None,
- ) -> List[str]:
-     device = device or torch.device("cpu")
-     dtype = dtype or torch.float32
-
-     accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
-
-     unexpected_keys = []
-     empty_state_dict = model.state_dict()
-     for param_name, param in state_dict.items():
-         if param_name not in empty_state_dict:
-             unexpected_keys.append(param_name)
-             continue
-
-         if empty_state_dict[param_name].shape != param.shape:
-             model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
-             raise ValueError(
-                 f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
-             )
-
-         if accepts_dtype:
-             set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
-         else:
-             set_module_tensor_to_device(model, param_name, device, value=param)
-     return unexpected_keys
-
-
- def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
-     # Convert old format to new format if needed from a PyTorch state_dict
-     # copy state_dict so _load_from_state_dict can modify it
-     state_dict = state_dict.copy()
-     error_msgs = []
-
-     # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
-     # so we need to apply the function recursively.
-     def load(module: torch.nn.Module, prefix: str = ""):
-         args = (state_dict, prefix, {}, True, [], [], error_msgs)
-         module._load_from_state_dict(*args)
-
-         for name, child in module._modules.items():
-             if child is not None:
-                 load(child, prefix + name + ".")
-
-     load(model_to_load)
-
-     return error_msgs
-
-
  class ModelMixin(torch.nn.Module, PushToHubMixin):
      r"""
      Base class for all models.
@@ -195,6 +120,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
      _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
      _supports_gradient_checkpointing = False
      _keys_to_ignore_on_load_unexpected = None
+     _no_split_modules = None

      def __init__(self):
          super().__init__()
@@ -241,6 +167,36 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
          if self._supports_gradient_checkpointing:
              self.apply(partial(self._set_gradient_checkpointing, value=False))

+     def set_use_npu_flash_attention(self, valid: bool) -> None:
+         r"""
+         Set the switch for the npu flash attention.
+         """
+
+         def fn_recursive_set_npu_flash_attention(module: torch.nn.Module):
+             if hasattr(module, "set_use_npu_flash_attention"):
+                 module.set_use_npu_flash_attention(valid)
+
+             for child in module.children():
+                 fn_recursive_set_npu_flash_attention(child)
+
+         for module in self.children():
+             if isinstance(module, torch.nn.Module):
+                 fn_recursive_set_npu_flash_attention(module)
+
+     def enable_npu_flash_attention(self) -> None:
+         r"""
+         Enable npu flash attention from torch_npu
+
+         """
+         self.set_use_npu_flash_attention(True)
+
+     def disable_npu_flash_attention(self) -> None:
+         r"""
+         disable npu flash attention from torch_npu
+
+         """
+         self.set_use_npu_flash_attention(False)
+
      def set_use_memory_efficient_attention_xformers(
          self, valid: bool, attention_op: Optional[Callable] = None
      ) -> None:
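These new methods recursively flip an NPU flash-attention switch on every submodule that implements `set_use_npu_flash_attention`. A usage sketch, assuming a `torch_npu`-enabled environment and using a public checkpoint id purely as an example:

    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"
    )
    unet.enable_npu_flash_attention()   # requires torch_npu; fails otherwise
    # ... run inference on the NPU ...
    unet.disable_npu_flash_attention()  # restore the default attention path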
@@ -367,18 +323,18 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
          # Save the model
          if safe_serialization:
              safetensors.torch.save_file(
-                 state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
+                 state_dict, Path(save_directory, weights_name).as_posix(), metadata={"format": "pt"}
              )
          else:
-             torch.save(state_dict, os.path.join(save_directory, weights_name))
+             torch.save(state_dict, Path(save_directory, weights_name).as_posix())

-         logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
+         logger.info(f"Model weights saved in {Path(save_directory, weights_name).as_posix()}")

          if push_to_hub:
              # Create a new empty model card and eventually tag it
              model_card = load_or_create_model_card(repo_id, token=token)
              model_card = populate_model_card(model_card)
-             model_card.save(os.path.join(save_directory, "README.md"))
+             model_card.save(Path(save_directory, "README.md").as_posix())

              self._upload_folder(
                  save_directory,
@@ -415,9 +371,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
              force_download (`bool`, *optional*, defaults to `False`):
                  Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                  cached versions if they exist.
-             resume_download (`bool`, *optional*, defaults to `False`):
-                 Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
-                 incompletely downloaded files are deleted.
+             resume_download:
+                 Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
+                 of Diffusers.
              proxies (`Dict[str, str]`, *optional*):
                  A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                  'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -499,7 +455,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
          ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
          force_download = kwargs.pop("force_download", False)
          from_flax = kwargs.pop("from_flax", False)
-         resume_download = kwargs.pop("resume_download", False)
+         resume_download = kwargs.pop("resume_download", None)
          proxies = kwargs.pop("proxies", None)
          output_loading_info = kwargs.pop("output_loading_info", False)
          local_files_only = kwargs.pop("local_files_only", None)
@@ -554,6 +510,36 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                  " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
              )

+         # change device_map into a map if we passed an int, a str or a torch.device
+         if isinstance(device_map, torch.device):
+             device_map = {"": device_map}
+         elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
+             try:
+                 device_map = {"": torch.device(device_map)}
+             except RuntimeError:
+                 raise ValueError(
+                     "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
+                     f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
+                 )
+         elif isinstance(device_map, int):
+             if device_map < 0:
+                 raise ValueError(
+                     "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
+                 )
+             else:
+                 device_map = {"": device_map}
+
+         if device_map is not None:
+             if low_cpu_mem_usage is None:
+                 low_cpu_mem_usage = True
+             elif not low_cpu_mem_usage:
+                 raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
+
+         if low_cpu_mem_usage:
+             if device_map is not None and not is_torch_version(">=", "1.10"):
+                 # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
+                 raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")
+
          # Load config if we don't provide a configuration
          config_path = pretrained_model_name_or_path

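With this hunk, `from_pretrained` normalizes scalar `device_map` values (a device string, a non-negative integer index, or a `torch.device`) into a single-entry map `{"": device}` before dispatch, and switches `low_cpu_mem_usage` on automatically whenever a device map is given. A sketch of the accepted forms, with the checkpoint id used only as an example:

    import torch
    from diffusers import UNet2DConditionModel

    # All three forms place the whole model on a single device.
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet", device_map="cuda:0"
    )
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet", device_map=0
    )
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet", device_map=torch.device("cuda")
    )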
@@ -576,10 +562,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
              token=token,
              revision=revision,
              subfolder=subfolder,
-             device_map=device_map,
-             max_memory=max_memory,
-             offload_folder=offload_folder,
-             offload_state_dict=offload_state_dict,
              user_agent=user_agent,
              **kwargs,
          )
@@ -684,6 +666,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
          else:  # else let accelerate handle loading and dispatching.
              # Load weights and dispatch according to the device_map
              # by default the device_map is None and the weights are loaded on the CPU
+             device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
              try:
                  accelerate.load_checkpoint_and_dispatch(
                      model,
@@ -693,6 +676,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                      offload_folder=offload_folder,
                      offload_state_dict=offload_state_dict,
                      dtype=torch_dtype,
+                     force_hooks=True,
+                     strict=True,
                  )
              except AttributeError as e:
                  # When using accelerate loading, we do not have the ability to load the state
@@ -873,6 +858,45 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):

          return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs

+     @classmethod
+     def _get_signature_keys(cls, obj):
+         parameters = inspect.signature(obj.__init__).parameters
+         required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
+         optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
+         expected_modules = set(required_parameters.keys()) - {"self"}
+
+         return expected_modules, optional_parameters
+
+     # Adapted from `transformers` modeling_utils.py
+     def _get_no_split_modules(self, device_map: str):
+         """
+         Get the modules of the model that should not be spit when using device_map. We iterate through the modules to
+         get the underlying `_no_split_modules`.
+
+         Args:
+             device_map (`str`):
+                 The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
+
+         Returns:
+             `List[str]`: List of modules that should not be split
+         """
+         _no_split_modules = set()
+         modules_to_check = [self]
+         while len(modules_to_check) > 0:
+             module = modules_to_check.pop(-1)
+             # if the module does not appear in _no_split_modules, we also check the children
+             if module.__class__.__name__ not in _no_split_modules:
+                 if isinstance(module, ModelMixin):
+                     if module._no_split_modules is None:
+                         raise ValueError(
+                             f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
+                             "class needs to implement the `_no_split_modules` attribute."
+                         )
+                     else:
+                         _no_split_modules = _no_split_modules | set(module._no_split_modules)
+                 modules_to_check += list(module.children())
+         return list(_no_split_modules)
+
      @property
      def device(self) -> torch.device:
          """
@@ -1019,3 +1043,55 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
              del module.key
              del module.value
              del module.proj_attn
+
+
+ class LegacyModelMixin(ModelMixin):
+     r"""
+     A subclass of `ModelMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
+     pipeline-specific classes (like `DiTTransformer2DModel`).
+     """
+
+     @classmethod
+     @validate_hf_hub_args
+     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+         # To prevent depedency import problem.
+         from .model_loading_utils import _fetch_remapped_cls_from_config
+
+         cache_dir = kwargs.pop("cache_dir", None)
+         force_download = kwargs.pop("force_download", False)
+         resume_download = kwargs.pop("resume_download", None)
+         proxies = kwargs.pop("proxies", None)
+         local_files_only = kwargs.pop("local_files_only", None)
+         token = kwargs.pop("token", None)
+         revision = kwargs.pop("revision", None)
+         subfolder = kwargs.pop("subfolder", None)
+
+         # Load config if we don't provide a configuration
+         config_path = pretrained_model_name_or_path
+
+         user_agent = {
+             "diffusers": __version__,
+             "file_type": "model",
+             "framework": "pytorch",
+         }
+
+         # load config
+         config, _, _ = cls.load_config(
+             config_path,
+             cache_dir=cache_dir,
+             return_unused_kwargs=True,
+             return_commit_hash=True,
+             force_download=force_download,
+             resume_download=resume_download,
+             proxies=proxies,
+             local_files_only=local_files_only,
+             token=token,
+             revision=revision,
+             subfolder=subfolder,
+             user_agent=user_agent,
+             **kwargs,
+         )
+         # resolve remapping
+         remapped_class = _fetch_remapped_cls_from_config(config, cls)
+
+         return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
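`LegacyModelMixin.from_pretrained` loads only the config, asks `_fetch_remapped_cls_from_config` which concrete class should own the checkpoint, and then defers to that class. The remapping is driven by `_CLASS_REMAPPING_DICT`, keyed on the config's `norm_type`. A small sketch of the rule, calling the (private) helper directly with a minimal stand-in config:

    from diffusers import Transformer2DModel
    from diffusers.models.model_loading_utils import _fetch_remapped_cls_from_config

    config = {"norm_type": "ada_norm_zero"}  # stand-in; a real model config has many more keys
    remapped = _fetch_remapped_cls_from_config(config, Transformer2DModel)
    print(remapped.__name__)  # DiTTransformer2DModel ("ada_norm_single" maps to PixArtTransformer2DModel)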
diffusers/models/normalization.py
@@ -176,7 +176,8 @@ class AdaLayerNormContinuous(nn.Module):
              raise ValueError(f"unknown norm_type {norm_type}")

      def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
-         emb = self.linear(self.silu(conditioning_embedding))
+         # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
+         emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
          scale, shift = torch.chunk(emb, 2, dim=1)
          x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
          return x