diffusers 0.27.2__py3-none-any.whl → 0.28.1__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. The information is provided for informational purposes only and reflects the changes between the two published versions.
Files changed (278)
  1. diffusers/__init__.py +26 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +33 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +8 -0
  21. diffusers/models/activations.py +23 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +475 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +363 -32
  35. diffusers/models/model_loading_utils.py +177 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_outputs.py +14 -0
  39. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  40. diffusers/models/modeling_utils.py +175 -99
  41. diffusers/models/normalization.py +2 -1
  42. diffusers/models/resnet.py +18 -23
  43. diffusers/models/transformer_temporal.py +3 -3
  44. diffusers/models/transformers/__init__.py +3 -0
  45. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  46. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  47. diffusers/models/transformers/hunyuan_transformer_2d.py +427 -0
  48. diffusers/models/transformers/pixart_transformer_2d.py +336 -0
  49. diffusers/models/transformers/prior_transformer.py +7 -7
  50. diffusers/models/transformers/t5_film_transformer.py +17 -19
  51. diffusers/models/transformers/transformer_2d.py +292 -184
  52. diffusers/models/transformers/transformer_temporal.py +10 -10
  53. diffusers/models/unets/unet_1d.py +5 -5
  54. diffusers/models/unets/unet_1d_blocks.py +29 -29
  55. diffusers/models/unets/unet_2d.py +6 -6
  56. diffusers/models/unets/unet_2d_blocks.py +137 -128
  57. diffusers/models/unets/unet_2d_condition.py +19 -15
  58. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  59. diffusers/models/unets/unet_3d_blocks.py +79 -77
  60. diffusers/models/unets/unet_3d_condition.py +13 -9
  61. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  62. diffusers/models/unets/unet_kandinsky3.py +1 -1
  63. diffusers/models/unets/unet_motion_model.py +114 -14
  64. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  65. diffusers/models/unets/unet_stable_cascade.py +16 -13
  66. diffusers/models/upsampling.py +17 -20
  67. diffusers/models/vq_model.py +16 -15
  68. diffusers/pipelines/__init__.py +27 -3
  69. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  70. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  71. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  72. diffusers/pipelines/animatediff/__init__.py +2 -0
  73. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  74. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  75. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  76. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  77. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  78. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  79. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  80. diffusers/pipelines/auto_pipeline.py +21 -17
  81. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  82. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  83. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  84. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  85. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  86. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  87. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  88. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  89. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  90. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  91. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  92. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  93. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  94. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  95. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  96. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  97. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  98. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  99. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  100. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  101. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  102. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  103. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  104. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  105. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  106. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  107. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  108. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  109. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
  110. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  111. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  112. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  113. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  114. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  115. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  116. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  117. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  118. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  119. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  120. diffusers/pipelines/free_init_utils.py +39 -38
  121. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  122. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +881 -0
  123. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  124. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  125. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  126. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  127. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  128. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  129. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  130. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  131. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  132. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  133. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  134. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  135. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  136. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  137. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  138. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  139. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  140. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  141. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  142. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  143. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  144. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  145. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  146. diffusers/pipelines/marigold/__init__.py +50 -0
  147. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  148. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  149. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  150. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  151. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  152. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  153. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  154. diffusers/pipelines/pipeline_loading_utils.py +269 -23
  155. diffusers/pipelines/pipeline_utils.py +266 -37
  156. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +69 -79
  158. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  159. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  160. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  161. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  162. diffusers/pipelines/shap_e/renderer.py +1 -1
  163. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
  164. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  165. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  166. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  167. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  168. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  169. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  172. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  173. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  174. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  175. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  176. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  177. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  178. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  179. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  180. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  181. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  182. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
  183. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  184. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  185. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  186. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  187. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  188. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  189. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  190. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  191. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  192. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  193. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  194. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  195. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  196. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  197. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  198. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  199. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  200. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  201. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  202. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  203. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  204. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  205. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  206. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  207. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  208. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  209. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  210. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  211. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  212. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  213. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  214. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  215. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  216. diffusers/schedulers/__init__.py +2 -2
  217. diffusers/schedulers/deprecated/__init__.py +1 -1
  218. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  219. diffusers/schedulers/scheduling_amused.py +5 -5
  220. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  221. diffusers/schedulers/scheduling_consistency_models.py +20 -26
  222. diffusers/schedulers/scheduling_ddim.py +22 -24
  223. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  224. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  225. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  226. diffusers/schedulers/scheduling_ddpm.py +20 -22
  227. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  228. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  229. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  230. diffusers/schedulers/scheduling_deis_multistep.py +42 -42
  231. diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
  232. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
  236. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
  237. diffusers/schedulers/scheduling_edm_euler.py +50 -31
  238. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
  239. diffusers/schedulers/scheduling_euler_discrete.py +160 -68
  240. diffusers/schedulers/scheduling_heun_discrete.py +57 -39
  241. diffusers/schedulers/scheduling_ipndm.py +8 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
  244. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  245. diffusers/schedulers/scheduling_lcm.py +21 -23
  246. diffusers/schedulers/scheduling_lms_discrete.py +24 -26
  247. diffusers/schedulers/scheduling_pndm.py +20 -20
  248. diffusers/schedulers/scheduling_repaint.py +20 -20
  249. diffusers/schedulers/scheduling_sasolver.py +55 -54
  250. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  251. diffusers/schedulers/scheduling_tcd.py +39 -30
  252. diffusers/schedulers/scheduling_unclip.py +15 -15
  253. diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
  254. diffusers/schedulers/scheduling_utils.py +14 -5
  255. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  256. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  257. diffusers/training_utils.py +56 -1
  258. diffusers/utils/__init__.py +7 -0
  259. diffusers/utils/doc_utils.py +1 -0
  260. diffusers/utils/dummy_pt_objects.py +75 -0
  261. diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
  262. diffusers/utils/dynamic_modules_utils.py +24 -11
  263. diffusers/utils/hub_utils.py +3 -2
  264. diffusers/utils/import_utils.py +91 -0
  265. diffusers/utils/loading_utils.py +2 -2
  266. diffusers/utils/logging.py +1 -1
  267. diffusers/utils/peft_utils.py +32 -5
  268. diffusers/utils/state_dict_utils.py +11 -2
  269. diffusers/utils/testing_utils.py +71 -6
  270. diffusers/utils/torch_utils.py +1 -0
  271. diffusers/video_processor.py +113 -0
  272. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/METADATA +7 -7
  273. diffusers-0.28.1.dist-info/RECORD +419 -0
  274. diffusers-0.27.2.dist-info/RECORD +0 -399
  275. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/LICENSE +0 -0
  276. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/WHEEL +0 -0
  277. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/entry_points.txt +0 -0
  278. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/top_level.txt +0 -0
diffusers/pipelines/pipeline_utils.py
@@ -21,7 +21,7 @@ import re
  import sys
  from dataclasses import dataclass
  from pathlib import Path
- from typing import Any, Callable, Dict, List, Optional, Union
+ from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin

  import numpy as np
  import PIL.Image
@@ -43,7 +43,7 @@ from .. import __version__
  from ..configuration_utils import ConfigMixin
  from ..models import AutoencoderKL
  from ..models.attention_processor import FusedAttnProcessor2_0
- from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
+ from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
  from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
  from ..utils import (
  CONFIG_NAME,
@@ -72,6 +72,8 @@ from .pipeline_loading_utils import (
  CUSTOM_PIPELINE_FILE_NAME,
  LOADABLE_CLASSES,
  _fetch_class_library_tuple,
+ _get_custom_pipeline_class,
+ _get_final_device_map,
  _get_pipeline_class,
  _unwrap_model,
  is_safetensors_compatible,
@@ -90,6 +92,8 @@ LIBRARIES = []
  for library in LOADABLE_CLASSES:
  LIBRARIES.append(library)

+ SUPPORTED_DEVICE_MAP = ["balanced"]
+
  logger = logging.get_logger(__name__)


@@ -140,6 +144,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):

  config_name = "model_index.json"
  model_cpu_offload_seq = None
+ hf_device_map = None
  _optional_components = []
  _exclude_from_cpu_offload = []
  _load_connected_pipes = False
@@ -371,8 +376,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
  return False

- return hasattr(module, "_hf_hook") and not isinstance(
- module._hf_hook, (accelerate.hooks.CpuOffload, accelerate.hooks.AlignDevicesHook)
+ return hasattr(module, "_hf_hook") and (
+ isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
+ or hasattr(module._hf_hook, "hooks")
+ and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook)
  )

  def module_is_offloaded(module):
@@ -390,6 +397,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
  )

+ is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
+ if is_pipeline_device_mapped:
+ raise ValueError(
+ "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`."
+ )
+
  # Display a warning in this case (the operation succeeds but the benefits are lost)
  pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
  if pipeline_is_offloaded and device and torch.device(device).type == "cuda":
@@ -520,9 +533,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  cache_dir (`Union[str, os.PathLike]`, *optional*):
  Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
  is not used.
- resume_download (`bool`, *optional*, defaults to `False`):
- Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
- incompletely downloaded files are deleted.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
+ of Diffusers.
  proxies (`Dict[str, str]`, *optional*):
  A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
  'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -539,7 +552,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  allowed by Git.
  custom_revision (`str`, *optional*):
  The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
- `revision` when loading a custom pipeline from the Hub. Defaults to the latest stable 🤗 Diffusers version.
+ `revision` when loading a custom pipeline from the Hub. Defaults to the latest stable 🤗 Diffusers
+ version.
  mirror (`str`, *optional*):
  Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
  guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
@@ -611,7 +625,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  ```
  """
  cache_dir = kwargs.pop("cache_dir", None)
- resume_download = kwargs.pop("resume_download", False)
+ resume_download = kwargs.pop("resume_download", None)
  force_download = kwargs.pop("force_download", False)
  proxies = kwargs.pop("proxies", None)
  local_files_only = kwargs.pop("local_files_only", None)
@@ -642,18 +656,35 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  " install accelerate\n```\n."
  )

+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
  if device_map is not None and not is_torch_version(">=", "1.9.0"):
  raise NotImplementedError(
  "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
  " `device_map=None`."
  )

- if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ if device_map is not None and not is_accelerate_available():
  raise NotImplementedError(
- "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
- " `low_cpu_mem_usage=False`."
+ "Using `device_map` requires the `accelerate` library. Please install it using: `pip install accelerate`."
  )

+ if device_map is not None and not isinstance(device_map, str):
+ raise ValueError("`device_map` must be a string.")
+
+ if device_map is not None and device_map not in SUPPORTED_DEVICE_MAP:
+ raise NotImplementedError(
+ f"{device_map} not supported. Supported strategies are: {', '.join(SUPPORTED_DEVICE_MAP)}"
+ )
+
+ if device_map is not None and device_map in SUPPORTED_DEVICE_MAP:
+ if is_accelerate_version("<", "0.28.0"):
+ raise NotImplementedError("Device placement requires `accelerate` version `0.28.0` or later.")
+
  if low_cpu_mem_usage is False and device_map is not None:
  raise ValueError(
  f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
@@ -729,6 +760,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  revision=custom_revision,
  )

+ if device_map is not None and pipeline_class._load_connected_pipes:
+ raise NotImplementedError("`device_map` is not yet supported for connected pipelines.")
+
  # DEPRECATED: To be removed in 1.0.0
  if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
  version.parse(config_dict["_diffusers_version"]).base_version
@@ -795,17 +829,45 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  # import it here to avoid circular import
  from diffusers import pipelines

- # 6. Load each module in the pipeline
+ # 6. device map delegation
+ final_device_map = None
+ if device_map is not None:
+ final_device_map = _get_final_device_map(
+ device_map=device_map,
+ pipeline_class=pipeline_class,
+ passed_class_obj=passed_class_obj,
+ init_dict=init_dict,
+ library=library,
+ max_memory=max_memory,
+ torch_dtype=torch_dtype,
+ cached_folder=cached_folder,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ )
+
+ # 7. Load each module in the pipeline
+ current_device_map = None
  for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
- # 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
+ if final_device_map is not None and len(final_device_map) > 0:
+ component_device = final_device_map.get(name, None)
+ if component_device is not None:
+ current_device_map = {"": component_device}
+ else:
+ current_device_map = None
+
+ # 7.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
  class_name = class_name[4:] if class_name.startswith("Flax") else class_name

- # 6.2 Define all importable classes
+ # 7.2 Define all importable classes
  is_pipeline_module = hasattr(pipelines, library_name)
  importable_classes = ALL_IMPORTABLE_CLASSES
  loaded_sub_model = None

- # 6.3 Use passed sub model or load class_name from library_name
+ # 7.3 Use passed sub model or load class_name from library_name
  if name in passed_class_obj:
  # if the model is in a pipeline module, then we load it from the pipeline
  # check that passed_class_obj has correct parent class
@@ -826,7 +888,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  torch_dtype=torch_dtype,
  provider=provider,
  sess_options=sess_options,
- device_map=device_map,
+ device_map=current_device_map,
  max_memory=max_memory,
  offload_folder=offload_folder,
  offload_state_dict=offload_state_dict,
@@ -893,7 +955,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  {"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
  )

- # 7. Potentially add passed objects if expected
+ # 8. Potentially add passed objects if expected
  missing_modules = set(expected_modules) - set(init_kwargs.keys())
  passed_modules = list(passed_class_obj.keys())
  optional_modules = pipeline_class._optional_components
@@ -906,11 +968,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
  )

- # 8. Instantiate the pipeline
+ # 10. Instantiate the pipeline
  model = pipeline_class(**init_kwargs)

- # 9. Save where the model was instantiated from
+ # 11. Save where the model was instantiated from
  model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+ if device_map is not None:
+ setattr(model, "hf_device_map", final_device_map)
  return model

  @property
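For orientation, a minimal sketch of how the new `device_map` path above is exercised from user code. This is not part of the diff: it assumes `accelerate>=0.28.0` is installed and borrows the checkpoint name from the `from_pipe` docstring further below.

```py
import torch
from diffusers import DiffusionPipeline

# "balanced" is the only entry in SUPPORTED_DEVICE_MAP; the checkpoint is illustrative.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    device_map="balanced",
)
# Per-component placement computed by _get_final_device_map and stored on the pipeline.
print(pipe.hf_device_map)
```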
@@ -939,6 +1003,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  return torch.device(module._hf_hook.execution_device)
  return self.device

+ def remove_all_hooks(self):
+ r"""
+ Removes all hooks that were added when using `enable_sequential_cpu_offload` or `enable_model_cpu_offload`.
+ """
+ for _, model in self.components.items():
+ if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
+ accelerate.hooks.remove_hook_from_module(model, recurse=True)
+ self._all_hooks = []
+
  def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
  r"""
  Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
@@ -953,6 +1026,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
  default to "cuda".
  """
+ is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
+ if is_pipeline_device_mapped:
+ raise ValueError(
+ "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
+ )
+
  if self.model_cpu_offload_seq is None:
  raise ValueError(
  "Model CPU offload cannot be enabled because no `model_cpu_offload_seq` class attribute is set."
@@ -963,6 +1042,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  else:
  raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

+ self.remove_all_hooks()
+
  torch_device = torch.device(device)
  device_index = torch_device.index

@@ -979,11 +1060,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  device = torch.device(f"{device_type}:{self._offload_gpu_id}")
  self._offload_device = device

- if self.device.type != "cpu":
- self.to("cpu", silence_dtype_warnings=True)
- device_mod = getattr(torch, self.device.type, None)
- if hasattr(device_mod, "empty_cache") and device_mod.is_available():
- device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+ self.to("cpu", silence_dtype_warnings=True)
+ device_mod = getattr(torch, device.type, None)
+ if hasattr(device_mod, "empty_cache") and device_mod.is_available():
+ device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

  all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}

@@ -1021,11 +1101,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  # `enable_model_cpu_offload` has not be called, so silently do nothing
  return

- for hook in self._all_hooks:
- # offload model and remove hook from model
- hook.offload()
- hook.remove()
-
  # make sure the model is in the same state as before calling it
  self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))

@@ -1048,6 +1123,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  from accelerate import cpu_offload
  else:
  raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
+ self.remove_all_hooks()
+
+ is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
+ if is_pipeline_device_mapped:
+ raise ValueError(
+ "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
+ )

  torch_device = torch.device(device)
  device_index = torch_device.index
@@ -1083,6 +1165,19 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  offload_buffers = len(model._parameters) > 0
  cpu_offload(model, device, offload_buffers=offload_buffers)

+ def reset_device_map(self):
+ r"""
+ Resets the device maps (if any) to None.
+ """
+ if self.hf_device_map is None:
+ return
+ else:
+ self.remove_all_hooks()
+ for name, component in self.components.items():
+ if isinstance(component, torch.nn.Module):
+ component.to("cpu")
+ self.hf_device_map = None
+
  @classmethod
  @validate_hf_hub_args
  def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
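To connect `reset_device_map()` with the guard added in `to()` above, a short hedged sketch, assuming a device-mapped pipeline like the one in the first sketch:

```py
pipe.reset_device_map()  # removes hooks, moves every torch.nn.Module component to CPU
pipe.to("cuda")          # explicit device placement is allowed again after the reset
```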
@@ -1121,9 +1216,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  force_download (`bool`, *optional*, defaults to `False`):
  Whether or not to force the (re-)download of the model weights and configuration files, overriding the
  cached versions if they exist.
- resume_download (`bool`, *optional*, defaults to `False`):
- Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
- incompletely downloaded files are deleted.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
+ of Diffusers.
  proxies (`Dict[str, str]`, *optional*):
  A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
  'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -1176,7 +1271,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):

  """
  cache_dir = kwargs.pop("cache_dir", None)
- resume_download = kwargs.pop("resume_download", False)
+ resume_download = kwargs.pop("resume_download", None)
  force_download = kwargs.pop("force_download", False)
  proxies = kwargs.pop("proxies", None)
  local_files_only = kwargs.pop("local_files_only", None)
@@ -1382,7 +1477,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):

  # Don't download index files of forbidden patterns either
  ignore_patterns = ignore_patterns + [f"{i}.index.*json" for i in ignore_patterns]
-
  re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns]
  re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns]

@@ -1472,6 +1566,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):

  return expected_modules, optional_parameters

+ @classmethod
+ def _get_signature_types(cls):
+ signature_types = {}
+ for k, v in inspect.signature(cls.__init__).parameters.items():
+ if inspect.isclass(v.annotation):
+ signature_types[k] = (v.annotation,)
+ elif get_origin(v.annotation) == Union:
+ signature_types[k] = get_args(v.annotation)
+ else:
+ logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.")
+ return signature_types
+
  @property
  def components(self) -> Dict[str, Any]:
  r"""
@@ -1650,6 +1756,129 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  for module in modules:
  module.set_attention_slice(slice_size)

+ @classmethod
+ def from_pipe(cls, pipeline, **kwargs):
+ r"""
+ Create a new pipeline from a given pipeline. This method is useful to create a new pipeline from the existing
+ pipeline components without reallocating additional memory.
+
+ Arguments:
+ pipeline (`DiffusionPipeline`):
+ The pipeline from which to create a new pipeline.
+
+ Returns:
+ `DiffusionPipeline`:
+ A new pipeline with the same weights and configurations as `pipeline`.
+
+ Examples:
+
+ ```py
+ >>> from diffusers import StableDiffusionPipeline, StableDiffusionSAGPipeline
+
+ >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> new_pipe = StableDiffusionSAGPipeline.from_pipe(pipe)
+ ```
+ """
+
+ original_config = dict(pipeline.config)
+ torch_dtype = kwargs.pop("torch_dtype", None)
+
+ # derive the pipeline class to instantiate
+ custom_pipeline = kwargs.pop("custom_pipeline", None)
+ custom_revision = kwargs.pop("custom_revision", None)
+
+ if custom_pipeline is not None:
+ pipeline_class = _get_custom_pipeline_class(custom_pipeline, revision=custom_revision)
+ else:
+ pipeline_class = cls
+
+ expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
+ # true_optional_modules are optional components with default value in signature so it is ok not to pass them to `__init__`
+ # e.g. `image_encoder` for StableDiffusionPipeline
+ parameters = inspect.signature(cls.__init__).parameters
+ true_optional_modules = set(
+ {k for k, v in parameters.items() if v.default != inspect._empty and k in expected_modules}
+ )
+
+ # get the class of each component based on its type hint
+ # e.g. {"unet": UNet2DConditionModel, "text_encoder": CLIPTextMode}
+ component_types = pipeline_class._get_signature_types()
+
+ pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
+ # allow users pass modules in `kwargs` to override the original pipeline's components
+ passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+
+ original_class_obj = {}
+ for name, component in pipeline.components.items():
+ if name in expected_modules and name not in passed_class_obj:
+ # for model components, we will not switch over if the class does not matches the type hint in the new pipeline's signature
+ if (
+ not isinstance(component, ModelMixin)
+ or type(component) in component_types[name]
+ or (component is None and name in cls._optional_components)
+ ):
+ original_class_obj[name] = component
+ else:
+ logger.warning(
+ f"component {name} is not switched over to new pipeline because type does not match the expected."
+ f" {name} is {type(component)} while the new pipeline expect {component_types[name]}."
+ f" please pass the component of the correct type to the new pipeline. `from_pipe(..., {name}={name})`"
+ )
+
+ # allow users pass optional kwargs to override the original pipelines config attribute
+ passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
+ original_pipe_kwargs = {
+ k: original_config[k]
+ for k in original_config.keys()
+ if k in optional_kwargs and k not in passed_pipe_kwargs
+ }
+
+ # config attribute that were not expected by pipeline is stored as its private attribute
+ # (i.e. when the original pipeline was also instantiated with `from_pipe` from another pipeline that has this config)
+ # in this case, we will pass them as optional arguments if they can be accepted by the new pipeline
+ additional_pipe_kwargs = [
+ k[1:]
+ for k in original_config.keys()
+ if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs
+ ]
+ for k in additional_pipe_kwargs:
+ original_pipe_kwargs[k] = original_config.pop(f"_{k}")
+
+ pipeline_kwargs = {
+ **passed_class_obj,
+ **original_class_obj,
+ **passed_pipe_kwargs,
+ **original_pipe_kwargs,
+ **kwargs,
+ }
+
+ # store unused config as private attribute in the new pipeline
+ unused_original_config = {
+ f"{'' if k.startswith('_') else '_'}{k}": v for k, v in original_config.items() if k not in pipeline_kwargs
+ }
+
+ missing_modules = (
+ set(expected_modules)
+ - set(pipeline._optional_components)
+ - set(pipeline_kwargs.keys())
+ - set(true_optional_modules)
+ )
+
+ if len(missing_modules) > 0:
+ raise ValueError(
+ f"Pipeline {pipeline_class} expected {expected_modules}, but only {set(list(passed_class_obj.keys()) + list(original_class_obj.keys()))} were passed"
+ )
+
+ new_pipeline = pipeline_class(**pipeline_kwargs)
+ if pretrained_model_name_or_path is not None:
+ new_pipeline.register_to_config(_name_or_path=pretrained_model_name_or_path)
+ new_pipeline.register_to_config(**unused_original_config)
+
+ if torch_dtype is not None:
+ new_pipeline.to(dtype=torch_dtype)
+
+ return new_pipeline
+

  class StableDiffusionMixin:
  r"""
@@ -1713,8 +1942,8 @@ class StableDiffusionMixin:

  def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
  """
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
- key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.

  <Tip warning={true}>

diffusers/pipelines/pixart_alpha/__init__.py
@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
  _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
  else:
  _import_structure["pipeline_pixart_alpha"] = ["PixArtAlphaPipeline"]
+ _import_structure["pipeline_pixart_sigma"] = ["PixArtSigmaPipeline"]

  if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
  try:
@@ -32,7 +33,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
  except OptionalDependencyNotAvailable:
  from ...utils.dummy_torch_and_transformers_objects import *
  else:
- from .pipeline_pixart_alpha import PixArtAlphaPipeline
+ from .pipeline_pixart_alpha import (
+ ASPECT_RATIO_256_BIN,
+ ASPECT_RATIO_512_BIN,
+ ASPECT_RATIO_1024_BIN,
+ PixArtAlphaPipeline,
+ )
+ from .pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN, PixArtSigmaPipeline

  else:
  import sys
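A hedged sketch of what this change makes importable. The submodule path follows the layout shown above; the top-level re-export is an assumption based on the `diffusers/__init__.py` update listed in this release.

```py
from diffusers import PixArtSigmaPipeline
from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN

print(PixArtSigmaPipeline.__name__, len(ASPECT_RATIO_2048_BIN))
```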