diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +19 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +20 -26
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +42 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
  229. diffusers/schedulers/scheduling_edm_euler.py +50 -31
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
  231. diffusers/schedulers/scheduling_euler_discrete.py +160 -68
  232. diffusers/schedulers/scheduling_heun_discrete.py +57 -39
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +24 -26
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
  267. diffusers-0.27.2.dist-info/RECORD +0 -399
  268. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  269. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -22,15 +22,19 @@ from pathlib import Path
22
22
  from typing import Any, Dict, List, Optional, Union
23
23
 
24
24
  import torch
25
- from huggingface_hub import (
26
- model_info,
27
- )
25
+ from huggingface_hub import model_info
26
+ from huggingface_hub.utils import validate_hf_hub_args
28
27
  from packaging import version
29
28
 
29
+ from .. import __version__
30
30
  from ..utils import (
31
+ FLAX_WEIGHTS_NAME,
32
+ ONNX_EXTERNAL_WEIGHTS_NAME,
33
+ ONNX_WEIGHTS_NAME,
31
34
  SAFETENSORS_WEIGHTS_NAME,
32
35
  WEIGHTS_NAME,
33
36
  get_class_from_dynamic_module,
37
+ is_accelerate_available,
34
38
  is_peft_available,
35
39
  is_transformers_available,
36
40
  logging,
@@ -44,9 +48,12 @@ if is_transformers_available():
44
48
  from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
45
49
  from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
46
50
  from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
47
- from huggingface_hub.utils import validate_hf_hub_args
48
51
 
49
- from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
52
+ if is_accelerate_available():
53
+ import accelerate
54
+ from accelerate import dispatch_model
55
+ from accelerate.hooks import remove_hook_from_module
56
+ from accelerate.utils import compute_module_sizes, get_max_memory
50
57
 
51
58
 
52
59
  INDEX_FILE = "diffusion_pytorch_model.bin"
@@ -292,6 +299,39 @@ def get_class_obj_and_candidates(
292
299
  return class_obj, class_candidates
293
300
 
294
301
 
302
+ def _get_custom_pipeline_class(
303
+ custom_pipeline,
304
+ repo_id=None,
305
+ hub_revision=None,
306
+ class_name=None,
307
+ cache_dir=None,
308
+ revision=None,
309
+ ):
310
+ if custom_pipeline.endswith(".py"):
311
+ path = Path(custom_pipeline)
312
+ # decompose into folder & file
313
+ file_name = path.name
314
+ custom_pipeline = path.parent.absolute()
315
+ elif repo_id is not None:
316
+ file_name = f"{custom_pipeline}.py"
317
+ custom_pipeline = repo_id
318
+ else:
319
+ file_name = CUSTOM_PIPELINE_FILE_NAME
320
+
321
+ if repo_id is not None and hub_revision is not None:
322
+ # if we load the pipeline code from the Hub
323
+ # make sure to overwrite the `revision`
324
+ revision = hub_revision
325
+
326
+ return get_class_from_dynamic_module(
327
+ custom_pipeline,
328
+ module_file=file_name,
329
+ class_name=class_name,
330
+ cache_dir=cache_dir,
331
+ revision=revision,
332
+ )
333
+
334
+
295
335
  def _get_pipeline_class(
296
336
  class_obj,
297
337
  config=None,
@@ -304,25 +344,10 @@ def _get_pipeline_class(
304
344
  revision=None,
305
345
  ):
306
346
  if custom_pipeline is not None:
307
- if custom_pipeline.endswith(".py"):
308
- path = Path(custom_pipeline)
309
- # decompose into folder & file
310
- file_name = path.name
311
- custom_pipeline = path.parent.absolute()
312
- elif repo_id is not None:
313
- file_name = f"{custom_pipeline}.py"
314
- custom_pipeline = repo_id
315
- else:
316
- file_name = CUSTOM_PIPELINE_FILE_NAME
317
-
318
- if repo_id is not None and hub_revision is not None:
319
- # if we load the pipeline code from the Hub
320
- # make sure to overwrite the `revision`
321
- revision = hub_revision
322
-
323
- return get_class_from_dynamic_module(
347
+ return _get_custom_pipeline_class(
324
348
  custom_pipeline,
325
- module_file=file_name,
349
+ repo_id=repo_id,
350
+ hub_revision=hub_revision,
326
351
  class_name=class_name,
327
352
  cache_dir=cache_dir,
328
353
  revision=revision,
@@ -358,6 +383,209 @@ def _get_pipeline_class(
358
383
  return pipeline_cls
359
384
 
360
385
 
386
+ def _load_empty_model(
387
+ library_name: str,
388
+ class_name: str,
389
+ importable_classes: List[Any],
390
+ pipelines: Any,
391
+ is_pipeline_module: bool,
392
+ name: str,
393
+ torch_dtype: Union[str, torch.dtype],
394
+ cached_folder: Union[str, os.PathLike],
395
+ **kwargs,
396
+ ):
397
+ # retrieve class objects.
398
+ class_obj, _ = get_class_obj_and_candidates(
399
+ library_name,
400
+ class_name,
401
+ importable_classes,
402
+ pipelines,
403
+ is_pipeline_module,
404
+ component_name=name,
405
+ cache_dir=cached_folder,
406
+ )
407
+
408
+ if is_transformers_available():
409
+ transformers_version = version.parse(version.parse(transformers.__version__).base_version)
410
+ else:
411
+ transformers_version = "N/A"
412
+
413
+ # Determine library.
414
+ is_transformers_model = (
415
+ is_transformers_available()
416
+ and issubclass(class_obj, PreTrainedModel)
417
+ and transformers_version >= version.parse("4.20.0")
418
+ )
419
+ diffusers_module = importlib.import_module(__name__.split(".")[0])
420
+ is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
421
+
422
+ model = None
423
+ config_path = cached_folder
424
+ user_agent = {
425
+ "diffusers": __version__,
426
+ "file_type": "model",
427
+ "framework": "pytorch",
428
+ }
429
+
430
+ if is_diffusers_model:
431
+ # Load config and then the model on meta.
432
+ config, unused_kwargs, commit_hash = class_obj.load_config(
433
+ os.path.join(config_path, name),
434
+ cache_dir=cached_folder,
435
+ return_unused_kwargs=True,
436
+ return_commit_hash=True,
437
+ force_download=kwargs.pop("force_download", False),
438
+ resume_download=kwargs.pop("resume_download", None),
439
+ proxies=kwargs.pop("proxies", None),
440
+ local_files_only=kwargs.pop("local_files_only", False),
441
+ token=kwargs.pop("token", None),
442
+ revision=kwargs.pop("revision", None),
443
+ subfolder=kwargs.pop("subfolder", None),
444
+ user_agent=user_agent,
445
+ )
446
+ with accelerate.init_empty_weights():
447
+ model = class_obj.from_config(config, **unused_kwargs)
448
+ elif is_transformers_model:
449
+ config_class = getattr(class_obj, "config_class", None)
450
+ if config_class is None:
451
+ raise ValueError("`config_class` cannot be None. Please double-check the model.")
452
+
453
+ config = config_class.from_pretrained(
454
+ cached_folder,
455
+ subfolder=name,
456
+ force_download=kwargs.pop("force_download", False),
457
+ resume_download=kwargs.pop("resume_download", None),
458
+ proxies=kwargs.pop("proxies", None),
459
+ local_files_only=kwargs.pop("local_files_only", False),
460
+ token=kwargs.pop("token", None),
461
+ revision=kwargs.pop("revision", None),
462
+ user_agent=user_agent,
463
+ )
464
+ with accelerate.init_empty_weights():
465
+ model = class_obj(config)
466
+
467
+ if model is not None:
468
+ model = model.to(dtype=torch_dtype)
469
+ return model
470
+
471
+
472
+ def _assign_components_to_devices(
473
+ module_sizes: Dict[str, float], device_memory: Dict[str, float], device_mapping_strategy: str = "balanced"
474
+ ):
475
+ device_ids = list(device_memory.keys())
476
+ device_cycle = device_ids + device_ids[::-1]
477
+ device_memory = device_memory.copy()
478
+
479
+ device_id_component_mapping = {}
480
+ current_device_index = 0
481
+ for component in module_sizes:
482
+ device_id = device_cycle[current_device_index % len(device_cycle)]
483
+ component_memory = module_sizes[component]
484
+ curr_device_memory = device_memory[device_id]
485
+
486
+ # If the GPU doesn't fit the current component offload to the CPU.
487
+ if component_memory > curr_device_memory:
488
+ device_id_component_mapping["cpu"] = [component]
489
+ else:
490
+ if device_id not in device_id_component_mapping:
491
+ device_id_component_mapping[device_id] = [component]
492
+ else:
493
+ device_id_component_mapping[device_id].append(component)
494
+
495
+ # Update the device memory.
496
+ device_memory[device_id] -= component_memory
497
+ current_device_index += 1
498
+
499
+ return device_id_component_mapping
500
+
501
+
502
+ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs):
503
+ # To avoid circular import problem.
504
+ from diffusers import pipelines
505
+
506
+ torch_dtype = kwargs.get("torch_dtype", torch.float32)
507
+
508
+ # Load each module in the pipeline on a meta device so that we can derive the device map.
509
+ init_empty_modules = {}
510
+ for name, (library_name, class_name) in init_dict.items():
511
+ if class_name.startswith("Flax"):
512
+ raise ValueError("Flax pipelines are not supported with `device_map`.")
513
+
514
+ # Define all importable classes
515
+ is_pipeline_module = hasattr(pipelines, library_name)
516
+ importable_classes = ALL_IMPORTABLE_CLASSES
517
+ loaded_sub_model = None
518
+
519
+ # Use passed sub model or load class_name from library_name
520
+ if name in passed_class_obj:
521
+ # if the model is in a pipeline module, then we load it from the pipeline
522
+ # check that passed_class_obj has correct parent class
523
+ maybe_raise_or_warn(
524
+ library_name,
525
+ library,
526
+ class_name,
527
+ importable_classes,
528
+ passed_class_obj,
529
+ name,
530
+ is_pipeline_module,
531
+ )
532
+ with accelerate.init_empty_weights():
533
+ loaded_sub_model = passed_class_obj[name]
534
+
535
+ else:
536
+ loaded_sub_model = _load_empty_model(
537
+ library_name=library_name,
538
+ class_name=class_name,
539
+ importable_classes=importable_classes,
540
+ pipelines=pipelines,
541
+ is_pipeline_module=is_pipeline_module,
542
+ pipeline_class=pipeline_class,
543
+ name=name,
544
+ torch_dtype=torch_dtype,
545
+ cached_folder=kwargs.get("cached_folder", None),
546
+ force_download=kwargs.get("force_download", None),
547
+ resume_download=kwargs.get("resume_download", None),
548
+ proxies=kwargs.get("proxies", None),
549
+ local_files_only=kwargs.get("local_files_only", None),
550
+ token=kwargs.get("token", None),
551
+ revision=kwargs.get("revision", None),
552
+ )
553
+
554
+ if loaded_sub_model is not None:
555
+ init_empty_modules[name] = loaded_sub_model
556
+
557
+ # determine device map
558
+ # Obtain a sorted dictionary for mapping the model-level components
559
+ # to their sizes.
560
+ module_sizes = {
561
+ module_name: compute_module_sizes(module, dtype=torch_dtype)[""]
562
+ for module_name, module in init_empty_modules.items()
563
+ if isinstance(module, torch.nn.Module)
564
+ }
565
+ module_sizes = dict(sorted(module_sizes.items(), key=lambda item: item[1], reverse=True))
566
+
567
+ # Obtain maximum memory available per device (GPUs only).
568
+ max_memory = get_max_memory(max_memory)
569
+ max_memory = dict(sorted(max_memory.items(), key=lambda item: item[1], reverse=True))
570
+ max_memory = {k: v for k, v in max_memory.items() if k != "cpu"}
571
+
572
+ # Obtain a dictionary mapping the model-level components to the available
573
+ # devices based on the maximum memory and the model sizes.
574
+ final_device_map = None
575
+ if len(max_memory) > 0:
576
+ device_id_component_mapping = _assign_components_to_devices(
577
+ module_sizes, max_memory, device_mapping_strategy=device_map
578
+ )
579
+
580
+ # Obtain the final device map, e.g., `{"unet": 0, "text_encoder": 1, "vae": 1, ...}`
581
+ final_device_map = {}
582
+ for device_id, components in device_id_component_mapping.items():
583
+ for component in components:
584
+ final_device_map[component] = device_id
585
+
586
+ return final_device_map
587
+
588
+
361
589
  def load_sub_model(
362
590
  library_name: str,
363
591
  class_name: str,
@@ -381,6 +609,7 @@ def load_sub_model(
381
609
  ):
382
610
  """Helper method to load the module `name` from `library_name` and `class_name`"""
383
611
  # retrieve class candidates
612
+
384
613
  class_obj, class_candidates = get_class_obj_and_candidates(
385
614
  library_name,
386
615
  class_name,
@@ -475,6 +704,22 @@ def load_sub_model(
475
704
  # else load from the root directory
476
705
  loaded_sub_model = load_method(cached_folder, **loading_kwargs)
477
706
 
707
+ if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict):
708
+ # remove hooks
709
+ remove_hook_from_module(loaded_sub_model, recurse=True)
710
+ needs_offloading_to_cpu = device_map[""] == "cpu"
711
+
712
+ if needs_offloading_to_cpu:
713
+ dispatch_model(
714
+ loaded_sub_model,
715
+ state_dict=loaded_sub_model.state_dict(),
716
+ device_map=device_map,
717
+ force_hooks=True,
718
+ main_device=0,
719
+ )
720
+ else:
721
+ dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)
722
+
478
723
  return loaded_sub_model
479
724
 
480
725