diffusers 0.27.2__py3-none-any.whl → 0.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. diffusers/__init__.py +26 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +33 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +8 -0
  21. diffusers/models/activations.py +23 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +475 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +363 -32
  35. diffusers/models/model_loading_utils.py +177 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_outputs.py +14 -0
  39. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  40. diffusers/models/modeling_utils.py +175 -99
  41. diffusers/models/normalization.py +2 -1
  42. diffusers/models/resnet.py +18 -23
  43. diffusers/models/transformer_temporal.py +3 -3
  44. diffusers/models/transformers/__init__.py +3 -0
  45. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  46. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  47. diffusers/models/transformers/hunyuan_transformer_2d.py +427 -0
  48. diffusers/models/transformers/pixart_transformer_2d.py +336 -0
  49. diffusers/models/transformers/prior_transformer.py +7 -7
  50. diffusers/models/transformers/t5_film_transformer.py +17 -19
  51. diffusers/models/transformers/transformer_2d.py +292 -184
  52. diffusers/models/transformers/transformer_temporal.py +10 -10
  53. diffusers/models/unets/unet_1d.py +5 -5
  54. diffusers/models/unets/unet_1d_blocks.py +29 -29
  55. diffusers/models/unets/unet_2d.py +6 -6
  56. diffusers/models/unets/unet_2d_blocks.py +137 -128
  57. diffusers/models/unets/unet_2d_condition.py +19 -15
  58. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  59. diffusers/models/unets/unet_3d_blocks.py +79 -77
  60. diffusers/models/unets/unet_3d_condition.py +13 -9
  61. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  62. diffusers/models/unets/unet_kandinsky3.py +1 -1
  63. diffusers/models/unets/unet_motion_model.py +114 -14
  64. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  65. diffusers/models/unets/unet_stable_cascade.py +16 -13
  66. diffusers/models/upsampling.py +17 -20
  67. diffusers/models/vq_model.py +16 -15
  68. diffusers/pipelines/__init__.py +27 -3
  69. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  70. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  71. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  72. diffusers/pipelines/animatediff/__init__.py +2 -0
  73. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  74. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  75. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  76. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  77. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  78. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  79. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  80. diffusers/pipelines/auto_pipeline.py +21 -17
  81. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  82. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  83. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  84. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  85. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  86. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  87. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  88. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  89. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  90. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  91. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  92. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  93. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  94. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  95. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  96. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  97. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  98. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  99. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  100. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  101. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  102. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  103. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  104. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  105. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  106. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  107. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  108. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  109. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
  110. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  111. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  112. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  113. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  114. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  115. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  116. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  117. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  118. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  119. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  120. diffusers/pipelines/free_init_utils.py +39 -38
  121. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  122. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +881 -0
  123. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  124. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  125. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  126. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  127. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  128. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  129. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  130. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  131. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  132. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  133. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  134. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  135. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  136. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  137. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  138. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  139. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  140. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  141. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  142. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  143. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  144. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  145. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  146. diffusers/pipelines/marigold/__init__.py +50 -0
  147. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  148. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  149. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  150. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  151. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  152. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  153. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  154. diffusers/pipelines/pipeline_loading_utils.py +269 -23
  155. diffusers/pipelines/pipeline_utils.py +266 -37
  156. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +69 -79
  158. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  159. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  160. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  161. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  162. diffusers/pipelines/shap_e/renderer.py +1 -1
  163. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
  164. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  165. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  166. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  167. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  168. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  169. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  172. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  173. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  174. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  175. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  176. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  177. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  178. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  179. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  180. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  181. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  182. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
  183. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  184. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  185. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  186. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  187. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  188. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  189. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  190. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  191. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  192. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  193. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  194. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  195. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  196. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  197. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  198. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  199. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  200. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  201. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  202. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  203. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  204. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  205. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  206. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  207. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  208. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  209. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  210. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  211. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  212. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  213. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  214. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  215. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  216. diffusers/schedulers/__init__.py +2 -2
  217. diffusers/schedulers/deprecated/__init__.py +1 -1
  218. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  219. diffusers/schedulers/scheduling_amused.py +5 -5
  220. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  221. diffusers/schedulers/scheduling_consistency_models.py +20 -26
  222. diffusers/schedulers/scheduling_ddim.py +22 -24
  223. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  224. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  225. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  226. diffusers/schedulers/scheduling_ddpm.py +20 -22
  227. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  228. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  229. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  230. diffusers/schedulers/scheduling_deis_multistep.py +42 -42
  231. diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
  232. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
  236. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
  237. diffusers/schedulers/scheduling_edm_euler.py +50 -31
  238. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
  239. diffusers/schedulers/scheduling_euler_discrete.py +160 -68
  240. diffusers/schedulers/scheduling_heun_discrete.py +57 -39
  241. diffusers/schedulers/scheduling_ipndm.py +8 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
  244. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  245. diffusers/schedulers/scheduling_lcm.py +21 -23
  246. diffusers/schedulers/scheduling_lms_discrete.py +24 -26
  247. diffusers/schedulers/scheduling_pndm.py +20 -20
  248. diffusers/schedulers/scheduling_repaint.py +20 -20
  249. diffusers/schedulers/scheduling_sasolver.py +55 -54
  250. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  251. diffusers/schedulers/scheduling_tcd.py +39 -30
  252. diffusers/schedulers/scheduling_unclip.py +15 -15
  253. diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
  254. diffusers/schedulers/scheduling_utils.py +14 -5
  255. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  256. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  257. diffusers/training_utils.py +56 -1
  258. diffusers/utils/__init__.py +7 -0
  259. diffusers/utils/doc_utils.py +1 -0
  260. diffusers/utils/dummy_pt_objects.py +75 -0
  261. diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
  262. diffusers/utils/dynamic_modules_utils.py +24 -11
  263. diffusers/utils/hub_utils.py +3 -2
  264. diffusers/utils/import_utils.py +91 -0
  265. diffusers/utils/loading_utils.py +2 -2
  266. diffusers/utils/logging.py +1 -1
  267. diffusers/utils/peft_utils.py +32 -5
  268. diffusers/utils/state_dict_utils.py +11 -2
  269. diffusers/utils/testing_utils.py +71 -6
  270. diffusers/utils/torch_utils.py +1 -0
  271. diffusers/video_processor.py +113 -0
  272. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/METADATA +7 -7
  273. diffusers-0.28.1.dist-info/RECORD +419 -0
  274. diffusers-0.27.2.dist-info/RECORD +0 -399
  275. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/LICENSE +0 -0
  276. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/WHEEL +0 -0
  277. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/entry_points.txt +0 -0
  278. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/top_level.txt +0 -0
diffusers/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = "0.27.2"
1
+ __version__ = "0.28.1"
2
2
 
3
3
  from typing import TYPE_CHECKING
4
4
 
@@ -27,6 +27,7 @@ from .utils import (
27
27
 
28
28
  _import_structure = {
29
29
  "configuration_utils": ["ConfigMixin"],
30
+ "loaders": ["FromOriginalModelMixin"],
30
31
  "models": [],
31
32
  "pipelines": [],
32
33
  "schedulers": [],
@@ -80,11 +81,15 @@ else:
80
81
  "AutoencoderTiny",
81
82
  "ConsistencyDecoderVAE",
82
83
  "ControlNetModel",
84
+ "ControlNetXSAdapter",
85
+ "DiTTransformer2DModel",
86
+ "HunyuanDiT2DModel",
83
87
  "I2VGenXLUNet",
84
88
  "Kandinsky3UNet",
85
89
  "ModelMixin",
86
90
  "MotionAdapter",
87
91
  "MultiAdapter",
92
+ "PixArtTransformer2DModel",
88
93
  "PriorTransformer",
89
94
  "StableCascadeUNet",
90
95
  "T2IAdapter",
@@ -94,6 +99,7 @@ else:
94
99
  "UNet2DConditionModel",
95
100
  "UNet2DModel",
96
101
  "UNet3DConditionModel",
102
+ "UNetControlNetXSModel",
97
103
  "UNetMotionModel",
98
104
  "UNetSpatioTemporalConditionModel",
99
105
  "UVit2DModel",
@@ -214,6 +220,7 @@ else:
214
220
  "AmusedInpaintPipeline",
215
221
  "AmusedPipeline",
216
222
  "AnimateDiffPipeline",
223
+ "AnimateDiffSDXLPipeline",
217
224
  "AnimateDiffVideoToVideoPipeline",
218
225
  "AudioLDM2Pipeline",
219
226
  "AudioLDM2ProjectionModel",
@@ -223,6 +230,7 @@ else:
223
230
  "BlipDiffusionPipeline",
224
231
  "CLIPImageProjection",
225
232
  "CycleDiffusionPipeline",
233
+ "HunyuanDiTPipeline",
226
234
  "I2VGenXLPipeline",
227
235
  "IFImg2ImgPipeline",
228
236
  "IFImg2ImgSuperResolutionPipeline",
@@ -255,10 +263,13 @@ else:
255
263
  "LDMTextToImagePipeline",
256
264
  "LEditsPPPipelineStableDiffusion",
257
265
  "LEditsPPPipelineStableDiffusionXL",
266
+ "MarigoldDepthPipeline",
267
+ "MarigoldNormalsPipeline",
258
268
  "MusicLDMPipeline",
259
269
  "PaintByExamplePipeline",
260
270
  "PIAPipeline",
261
271
  "PixArtAlphaPipeline",
272
+ "PixArtSigmaPipeline",
262
273
  "SemanticStableDiffusionPipeline",
263
274
  "ShapEImg2ImgPipeline",
264
275
  "ShapEPipeline",
@@ -270,6 +281,7 @@ else:
270
281
  "StableDiffusionControlNetImg2ImgPipeline",
271
282
  "StableDiffusionControlNetInpaintPipeline",
272
283
  "StableDiffusionControlNetPipeline",
284
+ "StableDiffusionControlNetXSPipeline",
273
285
  "StableDiffusionDepth2ImgPipeline",
274
286
  "StableDiffusionDiffEditPipeline",
275
287
  "StableDiffusionGLIGENPipeline",
@@ -293,6 +305,7 @@ else:
293
305
  "StableDiffusionXLControlNetImg2ImgPipeline",
294
306
  "StableDiffusionXLControlNetInpaintPipeline",
295
307
  "StableDiffusionXLControlNetPipeline",
308
+ "StableDiffusionXLControlNetXSPipeline",
296
309
  "StableDiffusionXLImg2ImgPipeline",
297
310
  "StableDiffusionXLInpaintPipeline",
298
311
  "StableDiffusionXLInstructPix2PixPipeline",
@@ -474,11 +487,15 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
474
487
  AutoencoderTiny,
475
488
  ConsistencyDecoderVAE,
476
489
  ControlNetModel,
490
+ ControlNetXSAdapter,
491
+ DiTTransformer2DModel,
492
+ HunyuanDiT2DModel,
477
493
  I2VGenXLUNet,
478
494
  Kandinsky3UNet,
479
495
  ModelMixin,
480
496
  MotionAdapter,
481
497
  MultiAdapter,
498
+ PixArtTransformer2DModel,
482
499
  PriorTransformer,
483
500
  T2IAdapter,
484
501
  T5FilmDecoder,
@@ -487,6 +504,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
487
504
  UNet2DConditionModel,
488
505
  UNet2DModel,
489
506
  UNet3DConditionModel,
507
+ UNetControlNetXSModel,
490
508
  UNetMotionModel,
491
509
  UNetSpatioTemporalConditionModel,
492
510
  UVit2DModel,
@@ -588,6 +606,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
588
606
  AmusedInpaintPipeline,
589
607
  AmusedPipeline,
590
608
  AnimateDiffPipeline,
609
+ AnimateDiffSDXLPipeline,
591
610
  AnimateDiffVideoToVideoPipeline,
592
611
  AudioLDM2Pipeline,
593
612
  AudioLDM2ProjectionModel,
@@ -595,6 +614,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
595
614
  AudioLDMPipeline,
596
615
  CLIPImageProjection,
597
616
  CycleDiffusionPipeline,
617
+ HunyuanDiTPipeline,
598
618
  I2VGenXLPipeline,
599
619
  IFImg2ImgPipeline,
600
620
  IFImg2ImgSuperResolutionPipeline,
@@ -627,10 +647,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
627
647
  LDMTextToImagePipeline,
628
648
  LEditsPPPipelineStableDiffusion,
629
649
  LEditsPPPipelineStableDiffusionXL,
650
+ MarigoldDepthPipeline,
651
+ MarigoldNormalsPipeline,
630
652
  MusicLDMPipeline,
631
653
  PaintByExamplePipeline,
632
654
  PIAPipeline,
633
655
  PixArtAlphaPipeline,
656
+ PixArtSigmaPipeline,
634
657
  SemanticStableDiffusionPipeline,
635
658
  ShapEImg2ImgPipeline,
636
659
  ShapEPipeline,
@@ -642,6 +665,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
642
665
  StableDiffusionControlNetImg2ImgPipeline,
643
666
  StableDiffusionControlNetInpaintPipeline,
644
667
  StableDiffusionControlNetPipeline,
668
+ StableDiffusionControlNetXSPipeline,
645
669
  StableDiffusionDepth2ImgPipeline,
646
670
  StableDiffusionDiffEditPipeline,
647
671
  StableDiffusionGLIGENPipeline,
@@ -665,6 +689,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
665
689
  StableDiffusionXLControlNetImg2ImgPipeline,
666
690
  StableDiffusionXLControlNetInpaintPipeline,
667
691
  StableDiffusionXLControlNetPipeline,
692
+ StableDiffusionXLControlNetXSPipeline,
668
693
  StableDiffusionXLImg2ImgPipeline,
669
694
  StableDiffusionXLInpaintPipeline,
670
695
  StableDiffusionXLInstructPix2PixPipeline,
diffusers/callbacks.py ADDED
@@ -0,0 +1,156 @@
1
+ from typing import Any, Dict, List
2
+
3
+ from .configuration_utils import ConfigMixin, register_to_config
4
+ from .utils import CONFIG_NAME
5
+
6
+
7
+ class PipelineCallback(ConfigMixin):
8
+ """
9
+ Base class for all the official callbacks used in a pipeline. This class provides a structure for implementing
10
+ custom callbacks and ensures that all callbacks have a consistent interface.
11
+
12
+ Please implement the following:
13
+ `tensor_inputs`: This should return a list of tensor inputs specific to your callback. You will only be able to
14
+ include
15
+ variables listed in the `._callback_tensor_inputs` attribute of your pipeline class.
16
+ `callback_fn`: This method defines the core functionality of your callback.
17
+ """
18
+
19
+ config_name = CONFIG_NAME
20
+
21
+ @register_to_config
22
+ def __init__(self, cutoff_step_ratio=1.0, cutoff_step_index=None):
23
+ super().__init__()
24
+
25
+ if (cutoff_step_ratio is None and cutoff_step_index is None) or (
26
+ cutoff_step_ratio is not None and cutoff_step_index is not None
27
+ ):
28
+ raise ValueError("Either cutoff_step_ratio or cutoff_step_index should be provided, not both or none.")
29
+
30
+ if cutoff_step_ratio is not None and (
31
+ not isinstance(cutoff_step_ratio, float) or not (0.0 <= cutoff_step_ratio <= 1.0)
32
+ ):
33
+ raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.")
34
+
35
+ @property
36
+ def tensor_inputs(self) -> List[str]:
37
+ raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}")
38
+
39
+ def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> Dict[str, Any]:
40
+ raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}")
41
+
42
+ def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
43
+ return self.callback_fn(pipeline, step_index, timestep, callback_kwargs)
44
+
45
+
46
+ class MultiPipelineCallbacks:
47
+ """
48
+ This class is designed to handle multiple pipeline callbacks. It accepts a list of PipelineCallback objects and
49
+ provides a unified interface for calling all of them.
50
+ """
51
+
52
+ def __init__(self, callbacks: List[PipelineCallback]):
53
+ self.callbacks = callbacks
54
+
55
+ @property
56
+ def tensor_inputs(self) -> List[str]:
57
+ return [input for callback in self.callbacks for input in callback.tensor_inputs]
58
+
59
+ def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
60
+ """
61
+ Calls all the callbacks in order with the given arguments and returns the final callback_kwargs.
62
+ """
63
+ for callback in self.callbacks:
64
+ callback_kwargs = callback(pipeline, step_index, timestep, callback_kwargs)
65
+
66
+ return callback_kwargs
67
+
68
+
69
+ class SDCFGCutoffCallback(PipelineCallback):
70
+ """
71
+ Callback function for Stable Diffusion Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
72
+ `cutoff_step_index`), this callback will disable the CFG.
73
+
74
+ Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
75
+ """
76
+
77
+ tensor_inputs = ["prompt_embeds"]
78
+
79
+ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
80
+ cutoff_step_ratio = self.config.cutoff_step_ratio
81
+ cutoff_step_index = self.config.cutoff_step_index
82
+
83
+ # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
84
+ cutoff_step = (
85
+ cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
86
+ )
87
+
88
+ if step_index == cutoff_step:
89
+ prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
90
+ prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens.
91
+
92
+ pipeline._guidance_scale = 0.0
93
+
94
+ callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
95
+ return callback_kwargs
96
+
97
+
98
+ class SDXLCFGCutoffCallback(PipelineCallback):
99
+ """
100
+ Callback function for Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
101
+ `cutoff_step_index`), this callback will disable the CFG.
102
+
103
+ Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
104
+ """
105
+
106
+ tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]
107
+
108
+ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
109
+ cutoff_step_ratio = self.config.cutoff_step_ratio
110
+ cutoff_step_index = self.config.cutoff_step_index
111
+
112
+ # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
113
+ cutoff_step = (
114
+ cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
115
+ )
116
+
117
+ if step_index == cutoff_step:
118
+ prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
119
+ prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens.
120
+
121
+ add_text_embeds = callback_kwargs[self.tensor_inputs[1]]
122
+ add_text_embeds = add_text_embeds[-1:] # "-1" denotes the embeddings for conditional pooled text tokens
123
+
124
+ add_time_ids = callback_kwargs[self.tensor_inputs[2]]
125
+ add_time_ids = add_time_ids[-1:] # "-1" denotes the embeddings for conditional added time vector
126
+
127
+ pipeline._guidance_scale = 0.0
128
+
129
+ callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
130
+ callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
131
+ callback_kwargs[self.tensor_inputs[2]] = add_time_ids
132
+ return callback_kwargs
133
+
134
+
135
+ class IPAdapterScaleCutoffCallback(PipelineCallback):
136
+ """
137
+ Callback function for any pipeline that inherits `IPAdapterMixin`. After certain number of steps (set by
138
+ `cutoff_step_ratio` or `cutoff_step_index`), this callback will set the IP Adapter scale to `0.0`.
139
+
140
+ Note: This callback mutates the IP Adapter attention processors by setting the scale to 0.0 after the cutoff step.
141
+ """
142
+
143
+ tensor_inputs = []
144
+
145
+ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
146
+ cutoff_step_ratio = self.config.cutoff_step_ratio
147
+ cutoff_step_index = self.config.cutoff_step_index
148
+
149
+ # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
150
+ cutoff_step = (
151
+ cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
152
+ )
153
+
154
+ if step_index == cutoff_step:
155
+ pipeline.set_ip_adapter_scale(0.0)
156
+ return callback_kwargs
diffusers/commands/env.py CHANGED
@@ -13,12 +13,25 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import platform
16
+ import subprocess
16
17
  from argparse import ArgumentParser
17
18
 
18
19
  import huggingface_hub
19
20
 
20
21
  from .. import __version__ as version
21
- from ..utils import is_accelerate_available, is_torch_available, is_transformers_available, is_xformers_available
22
+ from ..utils import (
23
+ is_accelerate_available,
24
+ is_bitsandbytes_available,
25
+ is_flax_available,
26
+ is_google_colab,
27
+ is_notebook,
28
+ is_peft_available,
29
+ is_safetensors_available,
30
+ is_torch_available,
31
+ is_transformers_available,
32
+ is_xformers_available,
33
+ )
34
+ from ..utils.testing_utils import get_python_version
22
35
  from . import BaseDiffusersCLICommand
23
36
 
24
37
 
@@ -28,13 +41,19 @@ def info_command_factory(_):
28
41
 
29
42
  class EnvironmentCommand(BaseDiffusersCLICommand):
30
43
  @staticmethod
31
- def register_subcommand(parser: ArgumentParser):
44
+ def register_subcommand(parser: ArgumentParser) -> None:
32
45
  download_parser = parser.add_parser("env")
33
46
  download_parser.set_defaults(func=info_command_factory)
34
47
 
35
- def run(self):
48
+ def run(self) -> dict:
36
49
  hub_version = huggingface_hub.__version__
37
50
 
51
+ safetensors_version = "not installed"
52
+ if is_safetensors_available():
53
+ import safetensors
54
+
55
+ safetensors_version = safetensors.__version__
56
+
38
57
  pt_version = "not installed"
39
58
  pt_cuda_available = "NA"
40
59
  if is_torch_available():
@@ -43,6 +62,20 @@ class EnvironmentCommand(BaseDiffusersCLICommand):
43
62
  pt_version = torch.__version__
44
63
  pt_cuda_available = torch.cuda.is_available()
45
64
 
65
+ flax_version = "not installed"
66
+ jax_version = "not installed"
67
+ jaxlib_version = "not installed"
68
+ jax_backend = "NA"
69
+ if is_flax_available():
70
+ import flax
71
+ import jax
72
+ import jaxlib
73
+
74
+ flax_version = flax.__version__
75
+ jax_version = jax.__version__
76
+ jaxlib_version = jaxlib.__version__
77
+ jax_backend = jax.lib.xla_bridge.get_backend().platform
78
+
46
79
  transformers_version = "not installed"
47
80
  if is_transformers_available():
48
81
  import transformers
@@ -55,21 +88,92 @@ class EnvironmentCommand(BaseDiffusersCLICommand):
55
88
 
56
89
  accelerate_version = accelerate.__version__
57
90
 
91
+ peft_version = "not installed"
92
+ if is_peft_available():
93
+ import peft
94
+
95
+ peft_version = peft.__version__
96
+
97
+ bitsandbytes_version = "not installed"
98
+ if is_bitsandbytes_available():
99
+ import bitsandbytes
100
+
101
+ bitsandbytes_version = bitsandbytes.__version__
102
+
58
103
  xformers_version = "not installed"
59
104
  if is_xformers_available():
60
105
  import xformers
61
106
 
62
107
  xformers_version = xformers.__version__
63
108
 
109
+ if get_python_version() >= (3, 10):
110
+ platform_info = f"{platform.freedesktop_os_release().get('PRETTY_NAME', None)} - {platform.platform()}"
111
+ else:
112
+ platform_info = platform.platform()
113
+
114
+ is_notebook_str = "Yes" if is_notebook() else "No"
115
+
116
+ is_google_colab_str = "Yes" if is_google_colab() else "No"
117
+
118
+ accelerator = "NA"
119
+ if platform.system() in {"Linux", "Windows"}:
120
+ try:
121
+ sp = subprocess.Popen(
122
+ ["nvidia-smi", "--query-gpu=gpu_name,memory.total", "--format=csv,noheader"],
123
+ stdout=subprocess.PIPE,
124
+ stderr=subprocess.PIPE,
125
+ )
126
+ out_str, _ = sp.communicate()
127
+ out_str = out_str.decode("utf-8")
128
+
129
+ if len(out_str) > 0:
130
+ accelerator = out_str.strip() + " VRAM"
131
+ except FileNotFoundError:
132
+ pass
133
+ elif platform.system() == "Darwin": # Mac OS
134
+ try:
135
+ sp = subprocess.Popen(
136
+ ["system_profiler", "SPDisplaysDataType"],
137
+ stdout=subprocess.PIPE,
138
+ stderr=subprocess.PIPE,
139
+ )
140
+ out_str, _ = sp.communicate()
141
+ out_str = out_str.decode("utf-8")
142
+
143
+ start = out_str.find("Chipset Model:")
144
+ if start != -1:
145
+ start += len("Chipset Model:")
146
+ end = out_str.find("\n", start)
147
+ accelerator = out_str[start:end].strip()
148
+
149
+ start = out_str.find("VRAM (Total):")
150
+ if start != -1:
151
+ start += len("VRAM (Total):")
152
+ end = out_str.find("\n", start)
153
+ accelerator += " VRAM: " + out_str[start:end].strip()
154
+ except FileNotFoundError:
155
+ pass
156
+ else:
157
+ print("It seems you are running an unusual OS. Could you fill in the accelerator manually?")
158
+
64
159
  info = {
65
- "`diffusers` version": version,
66
- "Platform": platform.platform(),
160
+ "🤗 Diffusers version": version,
161
+ "Platform": platform_info,
162
+ "Running on a notebook?": is_notebook_str,
163
+ "Running on Google Colab?": is_google_colab_str,
67
164
  "Python version": platform.python_version(),
68
165
  "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
166
+ "Flax version (CPU?/GPU?/TPU?)": f"{flax_version} ({jax_backend})",
167
+ "Jax version": jax_version,
168
+ "JaxLib version": jaxlib_version,
69
169
  "Huggingface_hub version": hub_version,
70
170
  "Transformers version": transformers_version,
71
171
  "Accelerate version": accelerate_version,
172
+ "PEFT version": peft_version,
173
+ "Bitsandbytes version": bitsandbytes_version,
174
+ "Safetensors version": safetensors_version,
72
175
  "xFormers version": xformers_version,
176
+ "Accelerator": accelerator,
73
177
  "Using GPU in script?": "<fill in>",
74
178
  "Using distributed or parallel set-up in script?": "<fill in>",
75
179
  }
@@ -80,5 +184,5 @@ class EnvironmentCommand(BaseDiffusersCLICommand):
80
184
  return info
81
185
 
82
186
  @staticmethod
83
- def format_dict(d):
187
+ def format_dict(d: dict) -> str:
84
188
  return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
@@ -13,7 +13,8 @@
13
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
  # See the License for the specific language governing permissions and
15
15
  # limitations under the License.
16
- """ ConfigMixin base class and utilities."""
16
+ """ConfigMixin base class and utilities."""
17
+
17
18
  import dataclasses
18
19
  import functools
19
20
  import importlib
@@ -309,9 +310,9 @@ class ConfigMixin:
309
310
  force_download (`bool`, *optional*, defaults to `False`):
310
311
  Whether or not to force the (re-)download of the model weights and configuration files, overriding the
311
312
  cached versions if they exist.
312
- resume_download (`bool`, *optional*, defaults to `False`):
313
- Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
314
- incompletely downloaded files are deleted.
313
+ resume_download:
314
+ Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
315
+ of Diffusers.
315
316
  proxies (`Dict[str, str]`, *optional*):
316
317
  A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
317
318
  'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -339,8 +340,10 @@ class ConfigMixin:
339
340
 
340
341
  """
341
342
  cache_dir = kwargs.pop("cache_dir", None)
343
+ local_dir = kwargs.pop("local_dir", None)
344
+ local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto")
342
345
  force_download = kwargs.pop("force_download", False)
343
- resume_download = kwargs.pop("resume_download", False)
346
+ resume_download = kwargs.pop("resume_download", None)
344
347
  proxies = kwargs.pop("proxies", None)
345
348
  token = kwargs.pop("token", None)
346
349
  local_files_only = kwargs.pop("local_files_only", False)
@@ -363,13 +366,13 @@ class ConfigMixin:
363
366
  if os.path.isfile(pretrained_model_name_or_path):
364
367
  config_file = pretrained_model_name_or_path
365
368
  elif os.path.isdir(pretrained_model_name_or_path):
366
- if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
367
- # Load from a PyTorch checkpoint
368
- config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
369
- elif subfolder is not None and os.path.isfile(
369
+ if subfolder is not None and os.path.isfile(
370
370
  os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
371
371
  ):
372
372
  config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
373
+ elif os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
374
+ # Load from a PyTorch checkpoint
375
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
373
376
  else:
374
377
  raise EnvironmentError(
375
378
  f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
@@ -389,6 +392,8 @@ class ConfigMixin:
389
392
  user_agent=user_agent,
390
393
  subfolder=subfolder,
391
394
  revision=revision,
395
+ local_dir=local_dir,
396
+ local_dir_use_symlinks=local_dir_use_symlinks,
392
397
  )
393
398
  except RepositoryNotFoundError:
394
399
  raise EnvironmentError(
@@ -449,8 +454,8 @@ class ConfigMixin:
449
454
  return outputs
450
455
 
451
456
  @staticmethod
452
- def _get_init_keys(cls):
453
- return set(dict(inspect.signature(cls.__init__).parameters).keys())
457
+ def _get_init_keys(input_class):
458
+ return set(dict(inspect.signature(input_class.__init__).parameters).keys())
454
459
 
455
460
  @classmethod
456
461
  def extract_init_dict(cls, config_dict, **kwargs):
@@ -701,3 +706,20 @@ def flax_register_to_config(cls):
701
706
 
702
707
  cls.__init__ = init
703
708
  return cls
709
+
710
+
711
+ class LegacyConfigMixin(ConfigMixin):
712
+ r"""
713
+ A subclass of `ConfigMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
714
+ pipeline-specific classes (like `DiTTransformer2DModel`).
715
+ """
716
+
717
+ @classmethod
718
+ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
719
+ # To prevent depedency import problem.
720
+ from .models.model_loading_utils import _fetch_remapped_cls_from_config
721
+
722
+ # resolve remapping
723
+ remapped_class = _fetch_remapped_cls_from_config(config, cls)
724
+
725
+ return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
@@ -3,7 +3,7 @@
3
3
  # 2. run `make deps_table_update`
4
4
  deps = {
5
5
  "Pillow": "Pillow",
6
- "accelerate": "accelerate>=0.11.0",
6
+ "accelerate": "accelerate>=0.29.3",
7
7
  "compel": "compel==0.1.8",
8
8
  "datasets": "datasets",
9
9
  "filelock": "filelock",
@@ -42,4 +42,5 @@ deps = {
42
42
  "torchvision": "torchvision",
43
43
  "transformers": "transformers>=4.25.1",
44
44
  "urllib3": "urllib3<=2.0.0",
45
+ "black": "black",
45
46
  }