diffusers-0.27.0-py3-none-any.whl → diffusers-0.32.2-py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (445)
  1. diffusers/__init__.py +233 -6
  2. diffusers/callbacks.py +209 -0
  3. diffusers/commands/env.py +102 -6
  4. diffusers/configuration_utils.py +45 -16
  5. diffusers/dependency_versions_table.py +4 -3
  6. diffusers/image_processor.py +434 -110
  7. diffusers/loaders/__init__.py +42 -9
  8. diffusers/loaders/ip_adapter.py +626 -36
  9. diffusers/loaders/lora_base.py +900 -0
  10. diffusers/loaders/lora_conversion_utils.py +991 -125
  11. diffusers/loaders/lora_pipeline.py +3812 -0
  12. diffusers/loaders/peft.py +571 -7
  13. diffusers/loaders/single_file.py +405 -173
  14. diffusers/loaders/single_file_model.py +385 -0
  15. diffusers/loaders/single_file_utils.py +1783 -713
  16. diffusers/loaders/textual_inversion.py +41 -23
  17. diffusers/loaders/transformer_flux.py +181 -0
  18. diffusers/loaders/transformer_sd3.py +89 -0
  19. diffusers/loaders/unet.py +464 -540
  20. diffusers/loaders/unet_loader_utils.py +163 -0
  21. diffusers/models/__init__.py +76 -7
  22. diffusers/models/activations.py +65 -10
  23. diffusers/models/adapter.py +53 -53
  24. diffusers/models/attention.py +605 -18
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +4304 -687
  27. diffusers/models/autoencoders/__init__.py +8 -0
  28. diffusers/models/autoencoders/autoencoder_asym_kl.py +15 -17
  29. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  30. diffusers/models/autoencoders/autoencoder_kl.py +110 -28
  31. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  32. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
  33. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  34. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  35. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  36. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +19 -24
  37. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  38. diffusers/models/autoencoders/autoencoder_tiny.py +21 -18
  39. diffusers/models/autoencoders/consistency_decoder_vae.py +45 -20
  40. diffusers/models/autoencoders/vae.py +41 -29
  41. diffusers/models/autoencoders/vq_model.py +182 -0
  42. diffusers/models/controlnet.py +47 -800
  43. diffusers/models/controlnet_flux.py +70 -0
  44. diffusers/models/controlnet_sd3.py +68 -0
  45. diffusers/models/controlnet_sparsectrl.py +116 -0
  46. diffusers/models/controlnets/__init__.py +23 -0
  47. diffusers/models/controlnets/controlnet.py +872 -0
  48. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +9 -9
  49. diffusers/models/controlnets/controlnet_flux.py +536 -0
  50. diffusers/models/controlnets/controlnet_hunyuan.py +401 -0
  51. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  52. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  53. diffusers/models/controlnets/controlnet_union.py +832 -0
  54. diffusers/models/controlnets/controlnet_xs.py +1946 -0
  55. diffusers/models/controlnets/multicontrolnet.py +183 -0
  56. diffusers/models/downsampling.py +85 -18
  57. diffusers/models/embeddings.py +1856 -158
  58. diffusers/models/embeddings_flax.py +23 -9
  59. diffusers/models/model_loading_utils.py +480 -0
  60. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  61. diffusers/models/modeling_flax_utils.py +2 -7
  62. diffusers/models/modeling_outputs.py +14 -0
  63. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  64. diffusers/models/modeling_utils.py +611 -146
  65. diffusers/models/normalization.py +361 -20
  66. diffusers/models/resnet.py +18 -23
  67. diffusers/models/transformers/__init__.py +16 -0
  68. diffusers/models/transformers/auraflow_transformer_2d.py +544 -0
  69. diffusers/models/transformers/cogvideox_transformer_3d.py +542 -0
  70. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  71. diffusers/models/transformers/dual_transformer_2d.py +9 -8
  72. diffusers/models/transformers/hunyuan_transformer_2d.py +578 -0
  73. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  74. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  75. diffusers/models/transformers/pixart_transformer_2d.py +445 -0
  76. diffusers/models/transformers/prior_transformer.py +13 -13
  77. diffusers/models/transformers/sana_transformer.py +488 -0
  78. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  79. diffusers/models/transformers/t5_film_transformer.py +17 -19
  80. diffusers/models/transformers/transformer_2d.py +297 -187
  81. diffusers/models/transformers/transformer_allegro.py +422 -0
  82. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  83. diffusers/models/transformers/transformer_flux.py +593 -0
  84. diffusers/models/transformers/transformer_hunyuan_video.py +791 -0
  85. diffusers/models/transformers/transformer_ltx.py +469 -0
  86. diffusers/models/transformers/transformer_mochi.py +499 -0
  87. diffusers/models/transformers/transformer_sd3.py +461 -0
  88. diffusers/models/transformers/transformer_temporal.py +21 -19
  89. diffusers/models/unets/unet_1d.py +8 -8
  90. diffusers/models/unets/unet_1d_blocks.py +31 -31
  91. diffusers/models/unets/unet_2d.py +17 -10
  92. diffusers/models/unets/unet_2d_blocks.py +225 -149
  93. diffusers/models/unets/unet_2d_condition.py +50 -53
  94. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  95. diffusers/models/unets/unet_3d_blocks.py +192 -1057
  96. diffusers/models/unets/unet_3d_condition.py +22 -27
  97. diffusers/models/unets/unet_i2vgen_xl.py +22 -18
  98. diffusers/models/unets/unet_kandinsky3.py +2 -2
  99. diffusers/models/unets/unet_motion_model.py +1413 -89
  100. diffusers/models/unets/unet_spatio_temporal_condition.py +40 -16
  101. diffusers/models/unets/unet_stable_cascade.py +19 -18
  102. diffusers/models/unets/uvit_2d.py +2 -2
  103. diffusers/models/upsampling.py +95 -26
  104. diffusers/models/vq_model.py +12 -164
  105. diffusers/optimization.py +1 -1
  106. diffusers/pipelines/__init__.py +202 -3
  107. diffusers/pipelines/allegro/__init__.py +48 -0
  108. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  109. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  110. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  111. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  112. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  113. diffusers/pipelines/animatediff/__init__.py +8 -0
  114. diffusers/pipelines/animatediff/pipeline_animatediff.py +122 -109
  115. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1106 -0
  116. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1288 -0
  117. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1010 -0
  118. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +236 -180
  119. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  120. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  121. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  122. diffusers/pipelines/audioldm2/modeling_audioldm2.py +58 -39
  123. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +121 -36
  124. diffusers/pipelines/aura_flow/__init__.py +48 -0
  125. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +584 -0
  126. diffusers/pipelines/auto_pipeline.py +196 -28
  127. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  128. diffusers/pipelines/blip_diffusion/modeling_blip2.py +6 -6
  129. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  130. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  131. diffusers/pipelines/cogvideo/__init__.py +54 -0
  132. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +772 -0
  133. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  134. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +885 -0
  135. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +851 -0
  136. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  137. diffusers/pipelines/cogview3/__init__.py +47 -0
  138. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  139. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  140. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +6 -6
  141. diffusers/pipelines/controlnet/__init__.py +86 -80
  142. diffusers/pipelines/controlnet/multicontrolnet.py +7 -182
  143. diffusers/pipelines/controlnet/pipeline_controlnet.py +134 -87
  144. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  145. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +93 -77
  146. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +88 -197
  147. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +136 -90
  148. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +176 -80
  149. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +125 -89
  150. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  151. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  152. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  153. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  154. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  155. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1060 -0
  156. diffusers/pipelines/controlnet_sd3/__init__.py +57 -0
  157. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1133 -0
  158. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  159. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  160. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +916 -0
  161. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1111 -0
  162. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  163. diffusers/pipelines/deepfloyd_if/pipeline_if.py +16 -30
  164. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +20 -35
  165. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +23 -41
  166. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +22 -38
  167. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +25 -41
  168. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +19 -34
  169. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  170. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  171. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  172. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +70 -30
  173. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +48 -25
  174. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  175. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  176. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +21 -20
  177. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +27 -29
  178. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +33 -27
  179. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +33 -23
  180. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +36 -30
  181. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +102 -69
  182. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  183. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  184. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  185. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  186. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  187. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  188. diffusers/pipelines/flux/__init__.py +69 -0
  189. diffusers/pipelines/flux/modeling_flux.py +47 -0
  190. diffusers/pipelines/flux/pipeline_flux.py +957 -0
  191. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  192. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  193. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  194. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  195. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  196. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  197. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  198. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  199. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  200. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  201. diffusers/pipelines/flux/pipeline_output.py +37 -0
  202. diffusers/pipelines/free_init_utils.py +41 -38
  203. diffusers/pipelines/free_noise_utils.py +596 -0
  204. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  205. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  206. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  207. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  208. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +916 -0
  209. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  210. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  211. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +32 -29
  212. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  213. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  214. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  215. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  216. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +34 -31
  217. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  218. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  219. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  220. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  221. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  222. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  223. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  224. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +22 -35
  225. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +26 -37
  226. diffusers/pipelines/kolors/__init__.py +54 -0
  227. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  228. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1250 -0
  229. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  230. diffusers/pipelines/kolors/text_encoder.py +889 -0
  231. diffusers/pipelines/kolors/tokenizer.py +338 -0
  232. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +82 -62
  233. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +77 -60
  234. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +12 -12
  235. diffusers/pipelines/latte/__init__.py +48 -0
  236. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  237. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +80 -74
  238. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +85 -76
  239. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  240. diffusers/pipelines/ltx/__init__.py +50 -0
  241. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  242. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  243. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  244. diffusers/pipelines/lumina/__init__.py +48 -0
  245. diffusers/pipelines/lumina/pipeline_lumina.py +890 -0
  246. diffusers/pipelines/marigold/__init__.py +50 -0
  247. diffusers/pipelines/marigold/marigold_image_processing.py +576 -0
  248. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  249. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  250. diffusers/pipelines/mochi/__init__.py +48 -0
  251. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  252. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  253. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  254. diffusers/pipelines/pag/__init__.py +80 -0
  255. diffusers/pipelines/pag/pag_utils.py +243 -0
  256. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1328 -0
  257. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  258. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1610 -0
  259. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  260. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +969 -0
  261. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  262. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +865 -0
  263. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  264. diffusers/pipelines/pag/pipeline_pag_sd.py +1062 -0
  265. diffusers/pipelines/pag/pipeline_pag_sd_3.py +994 -0
  266. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  267. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +866 -0
  268. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  269. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  270. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1345 -0
  271. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1544 -0
  272. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1776 -0
  273. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  274. diffusers/pipelines/pia/pipeline_pia.py +74 -164
  275. diffusers/pipelines/pipeline_flax_utils.py +5 -10
  276. diffusers/pipelines/pipeline_loading_utils.py +515 -53
  277. diffusers/pipelines/pipeline_utils.py +411 -222
  278. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  279. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +76 -93
  280. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +873 -0
  281. diffusers/pipelines/sana/__init__.py +47 -0
  282. diffusers/pipelines/sana/pipeline_output.py +21 -0
  283. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  284. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +27 -23
  285. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  286. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  287. diffusers/pipelines/shap_e/renderer.py +1 -1
  288. diffusers/pipelines/stable_audio/__init__.py +50 -0
  289. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  290. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +756 -0
  291. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +71 -25
  292. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  293. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +35 -34
  294. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  295. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +20 -11
  296. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  297. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  298. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  299. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +145 -79
  300. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +43 -28
  301. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  302. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +100 -68
  303. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +109 -201
  304. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +131 -32
  305. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +247 -87
  306. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +30 -29
  307. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +35 -27
  308. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +49 -42
  309. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  310. diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
  311. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  312. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
  313. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1036 -0
  314. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1250 -0
  315. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +29 -20
  316. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +59 -58
  317. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +31 -25
  318. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +38 -22
  319. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -24
  320. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -23
  321. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +107 -67
  322. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +316 -69
  323. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  324. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  325. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +98 -30
  326. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +121 -83
  327. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +161 -105
  328. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +142 -218
  329. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -29
  330. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  331. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  332. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +69 -39
  333. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +105 -74
  334. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  335. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +29 -49
  336. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +32 -93
  337. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +37 -25
  338. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +54 -40
  339. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  340. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  341. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  342. diffusers/pipelines/unidiffuser/modeling_uvit.py +12 -12
  343. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +29 -28
  344. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  345. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  346. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +6 -8
  347. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  348. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  349. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +15 -14
  350. diffusers/{models/dual_transformer_2d.py → quantizers/__init__.py} +2 -6
  351. diffusers/quantizers/auto.py +139 -0
  352. diffusers/quantizers/base.py +233 -0
  353. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  354. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  355. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  356. diffusers/quantizers/gguf/__init__.py +1 -0
  357. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  358. diffusers/quantizers/gguf/utils.py +456 -0
  359. diffusers/quantizers/quantization_config.py +669 -0
  360. diffusers/quantizers/torchao/__init__.py +15 -0
  361. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  362. diffusers/schedulers/__init__.py +12 -2
  363. diffusers/schedulers/deprecated/__init__.py +1 -1
  364. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  365. diffusers/schedulers/scheduling_amused.py +5 -5
  366. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  367. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  368. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  369. diffusers/schedulers/scheduling_ddim.py +27 -26
  370. diffusers/schedulers/scheduling_ddim_cogvideox.py +452 -0
  371. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  372. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  373. diffusers/schedulers/scheduling_ddim_parallel.py +32 -31
  374. diffusers/schedulers/scheduling_ddpm.py +27 -30
  375. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  376. diffusers/schedulers/scheduling_ddpm_parallel.py +33 -36
  377. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  378. diffusers/schedulers/scheduling_deis_multistep.py +150 -50
  379. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  380. diffusers/schedulers/scheduling_dpmsolver_multistep.py +221 -84
  381. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  382. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +158 -52
  383. diffusers/schedulers/scheduling_dpmsolver_sde.py +153 -34
  384. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +275 -86
  385. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +81 -57
  386. diffusers/schedulers/scheduling_edm_euler.py +62 -39
  387. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +30 -29
  388. diffusers/schedulers/scheduling_euler_discrete.py +255 -74
  389. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +458 -0
  390. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +320 -0
  391. diffusers/schedulers/scheduling_heun_discrete.py +174 -46
  392. diffusers/schedulers/scheduling_ipndm.py +9 -9
  393. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +138 -29
  394. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +132 -26
  395. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  396. diffusers/schedulers/scheduling_lcm.py +23 -29
  397. diffusers/schedulers/scheduling_lms_discrete.py +105 -28
  398. diffusers/schedulers/scheduling_pndm.py +20 -20
  399. diffusers/schedulers/scheduling_repaint.py +21 -21
  400. diffusers/schedulers/scheduling_sasolver.py +157 -60
  401. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  402. diffusers/schedulers/scheduling_tcd.py +41 -36
  403. diffusers/schedulers/scheduling_unclip.py +19 -16
  404. diffusers/schedulers/scheduling_unipc_multistep.py +243 -47
  405. diffusers/schedulers/scheduling_utils.py +12 -5
  406. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  407. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  408. diffusers/training_utils.py +214 -30
  409. diffusers/utils/__init__.py +17 -1
  410. diffusers/utils/constants.py +3 -0
  411. diffusers/utils/doc_utils.py +1 -0
  412. diffusers/utils/dummy_pt_objects.py +592 -7
  413. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  414. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  415. diffusers/utils/dummy_torch_and_transformers_objects.py +1001 -71
  416. diffusers/utils/dynamic_modules_utils.py +34 -29
  417. diffusers/utils/export_utils.py +50 -6
  418. diffusers/utils/hub_utils.py +131 -17
  419. diffusers/utils/import_utils.py +210 -8
  420. diffusers/utils/loading_utils.py +118 -5
  421. diffusers/utils/logging.py +4 -2
  422. diffusers/utils/peft_utils.py +37 -7
  423. diffusers/utils/state_dict_utils.py +13 -2
  424. diffusers/utils/testing_utils.py +193 -11
  425. diffusers/utils/torch_utils.py +4 -0
  426. diffusers/video_processor.py +113 -0
  427. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/METADATA +82 -91
  428. diffusers-0.32.2.dist-info/RECORD +550 -0
  429. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/WHEEL +1 -1
  430. diffusers/loaders/autoencoder.py +0 -146
  431. diffusers/loaders/controlnet.py +0 -136
  432. diffusers/loaders/lora.py +0 -1349
  433. diffusers/models/prior_transformer.py +0 -12
  434. diffusers/models/t5_film_transformer.py +0 -70
  435. diffusers/models/transformer_2d.py +0 -25
  436. diffusers/models/transformer_temporal.py +0 -34
  437. diffusers/models/unet_1d.py +0 -26
  438. diffusers/models/unet_1d_blocks.py +0 -203
  439. diffusers/models/unet_2d.py +0 -27
  440. diffusers/models/unet_2d_blocks.py +0 -375
  441. diffusers/models/unet_2d_condition.py +0 -25
  442. diffusers-0.27.0.dist-info/RECORD +0 -399
  443. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/LICENSE +0 -0
  444. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/entry_points.txt +0 -0
  445. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/top_level.txt +0 -0
@@ -22,15 +22,20 @@ from pathlib import Path
22
22
  from typing import Any, Dict, List, Optional, Union
23
23
 
24
24
  import torch
25
- from huggingface_hub import (
26
- model_info,
27
- )
25
+ from huggingface_hub import ModelCard, model_info
26
+ from huggingface_hub.utils import validate_hf_hub_args
28
27
  from packaging import version
29
28
 
29
+ from .. import __version__
30
30
  from ..utils import (
31
+ FLAX_WEIGHTS_NAME,
32
+ ONNX_EXTERNAL_WEIGHTS_NAME,
33
+ ONNX_WEIGHTS_NAME,
31
34
  SAFETENSORS_WEIGHTS_NAME,
32
35
  WEIGHTS_NAME,
36
+ deprecate,
33
37
  get_class_from_dynamic_module,
38
+ is_accelerate_available,
34
39
  is_peft_available,
35
40
  is_transformers_available,
36
41
  logging,
@@ -44,9 +49,12 @@ if is_transformers_available():
44
49
  from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
45
50
  from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
46
51
  from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
47
- from huggingface_hub.utils import validate_hf_hub_args
48
52
 
49
- from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
53
+ if is_accelerate_available():
54
+ import accelerate
55
+ from accelerate import dispatch_model
56
+ from accelerate.hooks import remove_hook_from_module
57
+ from accelerate.utils import compute_module_sizes, get_max_memory
50
58
 
51
59
 
52
60
  INDEX_FILE = "diffusion_pytorch_model.bin"
@@ -82,49 +90,50 @@ for library in LOADABLE_CLASSES:
82
90
  ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
83
91
 
84
92
 
85
def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool:
    """
    Check whether a repository listing can be loaded purely from safetensors weights.

    Only paths of the form `<component>/<file>` (exactly one "/") are grouped by component. The listing is
    compatible when every such component, except those in `passed_components`, has at least one file whose
    extension is ".safetensors". If no component folder is found, the check falls back to looking for
    ".safetensors" anywhere in the remaining file names.

    Args:
        filenames: Iterable of repository-relative file paths using "/" as separator.
        passed_components: Component names to exclude from the check.
        folder_names: If given, only files whose parent folder is in this collection are considered.

    Returns:
        `True` if the listing is safetensors compatible, `False` otherwise.
    """
    skipped_components = passed_components or []
    if folder_names is not None:
        filenames = {path for path in filenames if os.path.split(path)[0] in folder_names}

    # Group file names by the component folder they live in.
    files_by_component = {}
    for path in filenames:
        parts = path.split("/")
        if len(parts) != 2:
            continue

        component, basename = parts
        if component in skipped_components:
            continue

        files_by_component.setdefault(component, []).append(basename)

    # Without component folders, look for safetensors files in the main directory.
    if not files_by_component:
        return any(".safetensors" in path for path in filenames)

    # Every component needs at least one safetensors file.
    return all(
        any(os.path.splitext(basename)[1] == ".safetensors" for basename in basenames)
        for basenames in files_by_component.values()
    )
@@ -189,10 +198,31 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
189
198
  variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
190
199
  return variant_filename
191
200
 
192
- for f in non_variant_filenames:
193
- variant_filename = convert_to_variant(f)
194
- if variant_filename not in usable_filenames:
195
- usable_filenames.add(f)
201
+ def find_component(filename):
202
+ if not len(filename.split("/")) == 2:
203
+ return
204
+ component = filename.split("/")[0]
205
+ return component
206
+
207
+ def has_sharded_variant(component, variant, variant_filenames):
208
+ # If component exists check for sharded variant index filename
209
+ # If component doesn't exist check main dir for sharded variant index filename
210
+ component = component + "/" if component else ""
211
+ variant_index_re = re.compile(
212
+ rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
213
+ )
214
+ return any(f for f in variant_filenames if variant_index_re.match(f) is not None)
215
+
216
+ for filename in non_variant_filenames:
217
+ if convert_to_variant(filename) in variant_filenames:
218
+ continue
219
+
220
+ component = find_component(filename)
221
+ # If a sharded variant exists skip adding to allowed patterns
222
+ if has_sharded_variant(component, variant, variant_filenames):
223
+ continue
224
+
225
+ usable_filenames.add(filename)
196
226
 
197
227
  return usable_filenames, variant_filenames
198
228
 
@@ -292,6 +322,39 @@ def get_class_obj_and_candidates(
292
322
  return class_obj, class_candidates
293
323
 
294
324
 
325
def _get_custom_pipeline_class(
    custom_pipeline,
    repo_id=None,
    hub_revision=None,
    class_name=None,
    cache_dir=None,
    revision=None,
):
    """
    Resolve where a custom pipeline lives and load its class via `get_class_from_dynamic_module`.

    Args:
        custom_pipeline: Path to a `.py` file, a file stem inside `repo_id`, or another location that is
            forwarded unchanged together with the default `CUSTOM_PIPELINE_FILE_NAME`.
        repo_id: Repository that hosts `<custom_pipeline>.py` when `custom_pipeline` is not a `.py` path.
        hub_revision: Revision that replaces `revision` when both it and `repo_id` are given.
        class_name: Forwarded to `get_class_from_dynamic_module`.
        cache_dir: Forwarded to `get_class_from_dynamic_module`.
        revision: Revision used unless overridden by `hub_revision`.

    Returns:
        Whatever `get_class_from_dynamic_module` returns for the resolved location.
    """
    if custom_pipeline.endswith(".py"):
        # Decompose an explicit file path into its folder and file name.
        script_path = Path(custom_pipeline)
        module_file, custom_pipeline = script_path.name, script_path.parent.absolute()
    elif repo_id is not None:
        module_file, custom_pipeline = f"{custom_pipeline}.py", repo_id
    else:
        module_file = CUSTOM_PIPELINE_FILE_NAME

    # When the pipeline code is loaded from the Hub, the Hub revision takes precedence.
    use_hub_revision = repo_id is not None and hub_revision is not None

    return get_class_from_dynamic_module(
        custom_pipeline,
        module_file=module_file,
        class_name=class_name,
        cache_dir=cache_dir,
        revision=hub_revision if use_hub_revision else revision,
    )
356
+
357
+
295
358
  def _get_pipeline_class(
296
359
  class_obj,
297
360
  config=None,
@@ -304,25 +367,10 @@ def _get_pipeline_class(
304
367
  revision=None,
305
368
  ):
306
369
  if custom_pipeline is not None:
307
- if custom_pipeline.endswith(".py"):
308
- path = Path(custom_pipeline)
309
- # decompose into folder & file
310
- file_name = path.name
311
- custom_pipeline = path.parent.absolute()
312
- elif repo_id is not None:
313
- file_name = f"{custom_pipeline}.py"
314
- custom_pipeline = repo_id
315
- else:
316
- file_name = CUSTOM_PIPELINE_FILE_NAME
317
-
318
- if repo_id is not None and hub_revision is not None:
319
- # if we load the pipeline code from the Hub
320
- # make sure to overwrite the `revision`
321
- revision = hub_revision
322
-
323
- return get_class_from_dynamic_module(
370
+ return _get_custom_pipeline_class(
324
371
  custom_pipeline,
325
- module_file=file_name,
372
+ repo_id=repo_id,
373
+ hub_revision=hub_revision,
326
374
  class_name=class_name,
327
375
  cache_dir=cache_dir,
328
376
  revision=revision,
@@ -358,6 +406,206 @@ def _get_pipeline_class(
358
406
  return pipeline_cls
359
407
 
360
408
 
409
def _load_empty_model(
    library_name: str,
    class_name: str,
    importable_classes: List[Any],
    pipelines: Any,
    is_pipeline_module: bool,
    name: str,
    torch_dtype: Union[str, torch.dtype],
    cached_folder: Union[str, os.PathLike],
    **kwargs,
):
    """
    Build the component `name` from its config inside `accelerate.init_empty_weights()`.

    Only diffusers models (subclasses of `ModelMixin`) and transformers models (subclasses of
    `PreTrainedModel`, transformers >= 4.20.0) are handled; for any other class `None` is returned.

    Args:
        library_name, class_name, importable_classes, pipelines, is_pipeline_module:
            Forwarded to `get_class_obj_and_candidates` to resolve the component class.
        name: Component name; also the sub-folder of `cached_folder` holding its config.
        torch_dtype: dtype the instantiated model is cast to.
        cached_folder: Folder that contains the pipeline files.
        **kwargs: `force_download`, `proxies`, `local_files_only`, `token`, `revision` (and `subfolder` for
            diffusers models) are consumed; any other entries are ignored.

    Returns:
        The instantiated model cast to `torch_dtype`, or `None` if the class is not a supported model type.
    """
    # retrieve class objects.
    class_obj, _ = get_class_obj_and_candidates(
        library_name,
        class_name,
        importable_classes,
        pipelines,
        is_pipeline_module,
        component_name=name,
        cache_dir=cached_folder,
    )

    # Determine library.
    is_transformers_model = False
    if is_transformers_available():
        installed_transformers = version.parse(version.parse(transformers.__version__).base_version)
        is_transformers_model = issubclass(class_obj, PreTrainedModel) and installed_transformers >= version.parse(
            "4.20.0"
        )
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)

    if not (is_diffusers_model or is_transformers_model):
        return None

    # Download-related options shared by both config loaders.
    hub_kwargs = {
        "force_download": kwargs.pop("force_download", False),
        "proxies": kwargs.pop("proxies", None),
        "local_files_only": kwargs.pop("local_files_only", False),
        "token": kwargs.pop("token", None),
        "revision": kwargs.pop("revision", None),
        "user_agent": {
            "diffusers": __version__,
            "file_type": "model",
            "framework": "pytorch",
        },
    }

    if is_diffusers_model:
        # Load config and then the model on meta.
        config, unused_kwargs, commit_hash = class_obj.load_config(
            os.path.join(cached_folder, name),
            cache_dir=cached_folder,
            return_unused_kwargs=True,
            return_commit_hash=True,
            subfolder=kwargs.pop("subfolder", None),
            **hub_kwargs,
        )
        with accelerate.init_empty_weights():
            model = class_obj.from_config(config, **unused_kwargs)
    else:
        config_class = getattr(class_obj, "config_class", None)
        if config_class is None:
            raise ValueError("`config_class` cannot be None. Please double-check the model.")

        config = config_class.from_pretrained(cached_folder, subfolder=name, **hub_kwargs)
        with accelerate.init_empty_weights():
            model = class_obj(config)

    if model is not None:
        model = model.to(dtype=torch_dtype)
    return model
491
+
492
+
493
def _assign_components_to_devices(
    module_sizes: Dict[str, float], device_memory: Dict[str, float], device_mapping_strategy: str = "balanced"
):
    """
    Distribute pipeline components over devices, falling back to the CPU when a component does not fit.

    Components are visited in the iteration order of `module_sizes`. Devices are visited forwards and then
    backwards (`[d0, d1, ..., d1, d0]`), and the cycle only advances after a component has been placed. A
    component larger than the remaining memory of the device under consideration is assigned to `"cpu"`.

    Args:
        module_sizes: Mapping from component name to its size.
        device_memory: Mapping from device id to its available memory. The mapping is not modified.
        device_mapping_strategy: Currently unused by this function; kept for interface compatibility.

    Returns:
        A dict mapping each used device id (or `"cpu"`) to the list of component names assigned to it.
    """
    device_ids = list(device_memory.keys())
    if not device_ids:
        # No accelerator is available: everything has to live on the CPU.
        return {"cpu": list(module_sizes)} if module_sizes else {}

    device_cycle = device_ids + device_ids[::-1]
    remaining_memory = device_memory.copy()

    device_id_component_mapping = {}
    current_device_index = 0
    for component, component_memory in module_sizes.items():
        device_id = device_cycle[current_device_index % len(device_cycle)]

        # If the GPU doesn't fit the current component offload to the CPU.
        if component_memory > remaining_memory[device_id]:
            # Accumulate instead of overwriting so that earlier CPU assignments are kept.
            device_id_component_mapping.setdefault("cpu", []).append(component)
            continue

        device_id_component_mapping.setdefault(device_id, []).append(component)

        # Update the device memory and move on to the next device in the cycle.
        remaining_memory[device_id] -= component_memory
        current_device_index += 1

    return device_id_component_mapping
521
+
522
+
523
def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs):
    """
    Derive a component-level device map for a pipeline.

    Every component in `init_dict` is either taken from `passed_class_obj` or instantiated through
    `_load_empty_model`; the sizes of the resulting `torch.nn.Module` components are then matched against the
    per-device memory reported by `get_max_memory` (the `"cpu"` entry is ignored).

    Args:
        device_map: Mapping strategy forwarded to `_assign_components_to_devices`.
        pipeline_class: Forwarded to `_load_empty_model`.
        passed_class_obj: Components supplied by the caller, keyed by component name.
        init_dict: Mapping from component name to `(library_name, class_name)`.
        library: Forwarded to `maybe_raise_or_warn`.
        max_memory: Forwarded to `get_max_memory`.
        **kwargs: `torch_dtype` (default `torch.float32`), `cached_folder`, `force_download`, `proxies`,
            `local_files_only`, `token` and `revision` are read.

    Returns:
        A dict such as `{"unet": 0, "text_encoder": 1, "vae": 1, ...}`, or `None` when no non-CPU device is
        reported.

    Raises:
        ValueError: If any class name in `init_dict` starts with "Flax".
    """
    # To avoid circular import problem.
    from diffusers import pipelines

    torch_dtype = kwargs.get("torch_dtype", torch.float32)
    loading_options = {
        key: kwargs.get(key, None)
        for key in ("cached_folder", "force_download", "proxies", "local_files_only", "token", "revision")
    }

    # Load each module in the pipeline on a meta device so that we can derive the device map.
    empty_modules = {}
    for name, (library_name, class_name) in init_dict.items():
        if class_name.startswith("Flax"):
            raise ValueError("Flax pipelines are not supported with `device_map`.")

        is_pipeline_module = hasattr(pipelines, library_name)

        if name in passed_class_obj:
            # A sub model was passed in: check that it has the correct parent class and use it as is.
            maybe_raise_or_warn(
                library_name,
                library,
                class_name,
                ALL_IMPORTABLE_CLASSES,
                passed_class_obj,
                name,
                is_pipeline_module,
            )
            with accelerate.init_empty_weights():
                sub_model = passed_class_obj[name]
        else:
            sub_model = _load_empty_model(
                library_name=library_name,
                class_name=class_name,
                importable_classes=ALL_IMPORTABLE_CLASSES,
                pipelines=pipelines,
                is_pipeline_module=is_pipeline_module,
                pipeline_class=pipeline_class,
                name=name,
                torch_dtype=torch_dtype,
                **loading_options,
            )

        if sub_model is not None:
            empty_modules[name] = sub_model

    # Component sizes, largest first.
    sizes = {
        module_name: compute_module_sizes(module, dtype=torch_dtype)[""]
        for module_name, module in empty_modules.items()
        if isinstance(module, torch.nn.Module)
    }
    sizes = dict(sorted(sizes.items(), key=lambda entry: entry[1], reverse=True))

    # Maximum memory available per device, largest first, without the CPU entry.
    available = get_max_memory(max_memory)
    available = {
        device: memory
        for device, memory in sorted(available.items(), key=lambda entry: entry[1], reverse=True)
        if device != "cpu"
    }
    if len(available) == 0:
        return None

    assignment = _assign_components_to_devices(sizes, available, device_mapping_strategy=device_map)

    # Invert `{device: [components]}` into `{component: device}`.
    return {component: device_id for device_id, components in assignment.items() for component in components}
607
+
608
+
361
609
  def load_sub_model(
362
610
  library_name: str,
363
611
  class_name: str,
@@ -378,9 +626,12 @@ def load_sub_model(
378
626
  variant: str,
379
627
  low_cpu_mem_usage: bool,
380
628
  cached_folder: Union[str, os.PathLike],
629
+ use_safetensors: bool,
381
630
  ):
382
631
  """Helper method to load the module `name` from `library_name` and `class_name`"""
632
+
383
633
  # retrieve class candidates
634
+
384
635
  class_obj, class_candidates = get_class_obj_and_candidates(
385
636
  library_name,
386
637
  class_name,
@@ -445,6 +696,7 @@ def load_sub_model(
445
696
  loading_kwargs["offload_folder"] = offload_folder
446
697
  loading_kwargs["offload_state_dict"] = offload_state_dict
447
698
  loading_kwargs["variant"] = model_variants.pop(name, None)
699
+ loading_kwargs["use_safetensors"] = use_safetensors
448
700
 
449
701
  if from_flax:
450
702
  loading_kwargs["from_flax"] = True
@@ -475,6 +727,22 @@ def load_sub_model(
475
727
  # else load from the root directory
476
728
  loaded_sub_model = load_method(cached_folder, **loading_kwargs)
477
729
 
730
+ if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict):
731
+ # remove hooks
732
+ remove_hook_from_module(loaded_sub_model, recurse=True)
733
+ needs_offloading_to_cpu = device_map[""] == "cpu"
734
+
735
+ if needs_offloading_to_cpu:
736
+ dispatch_model(
737
+ loaded_sub_model,
738
+ state_dict=loaded_sub_model.state_dict(),
739
+ device_map=device_map,
740
+ force_hooks=True,
741
+ main_device=0,
742
+ )
743
+ else:
744
+ dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)
745
+
478
746
  return loaded_sub_model
479
747
 
480
748
 
@@ -506,3 +774,197 @@ def _fetch_class_library_tuple(module):
506
774
  class_name = not_compiled_module.__class__.__name__
507
775
 
508
776
  return (library, class_name)
777
+
778
+
779
def _identify_model_variants(folder: str, variant: str, config: dict) -> dict:
    """
    Find the component sub-folders of `folder` that contain files for the requested `variant`.

    A sub-folder qualifies when it is a directory, its name is a key of `config`, and at least one entry in
    it has a second dot-separated segment starting with `variant` (e.g. `model.fp16.safetensors`). Entries
    without any dot are ignored instead of raising an `IndexError`.

    Args:
        folder: Directory containing the component sub-folders.
        variant: Variant name to look for; `None` disables the search.
        config: Pipeline config whose keys name the components to consider.

    Returns:
        A dict mapping each qualifying sub-folder name to `variant` (empty if `variant` is `None`).
    """
    model_variants = {}
    if variant is None:
        return model_variants

    for sub_folder in os.listdir(folder):
        folder_path = os.path.join(folder, sub_folder)
        if not (os.path.isdir(folder_path) and sub_folder in config):
            continue

        for entry in os.listdir(folder_path):
            segments = entry.split(".")
            # Names such as "LICENSE" or nested directories have no second segment.
            if len(segments) > 1 and segments[1].startswith(variant):
                model_variants[sub_folder] = variant
                break

    return model_variants
789
+
790
+
791
def _resolve_custom_pipeline_and_cls(folder, config, custom_pipeline):
    """
    Locate a custom pipeline file inside `folder`.

    `<folder>/<custom_pipeline>.py` is preferred. Otherwise, if `config["_class_name"]` is a list or tuple
    and `<folder>/<config["_class_name"][0]>.py` exists, that file is used and the second element is
    returned as the class name.

    Returns:
        A `(custom_pipeline, custom_class_name)` tuple. `custom_pipeline` is the path of the file that was
        found, or the unchanged argument if none was; `custom_class_name` is `None` unless it came from the
        config.
    """
    requested_file = os.path.join(folder, f"{custom_pipeline}.py")
    if os.path.isfile(requested_file):
        return requested_file, None

    declared = config["_class_name"]
    if isinstance(declared, (list, tuple)):
        declared_file = os.path.join(folder, f"{declared[0]}.py")
        if os.path.isfile(declared_file):
            return declared_file, declared[1]

    return custom_pipeline, None
802
+
803
+
804
def _maybe_raise_warning_for_inpainting(pipeline_class, pretrained_model_name_or_path: str, config: dict):
    """
    Emit a deprecation warning for legacy Stable Diffusion inpainting checkpoints.

    The warning is issued only when `pipeline_class` is named "StableDiffusionInpaintPipeline" and the base
    version in `config["_diffusers_version"]` is at most 0.5.1. This function only warns; it returns nothing
    and does not change which class the caller uses.

    Args:
        pipeline_class: Pipeline class about to be loaded.
        pretrained_model_name_or_path: Checkpoint identifier, used in the warning text only.
        config: Pipeline config; `_diffusers_version` is read when the class name matches.
    """
    if pipeline_class.__name__ != "StableDiffusionInpaintPipeline":
        return

    checkpoint_version = version.parse(version.parse(config["_diffusers_version"]).base_version)
    if checkpoint_version > version.parse("0.5.1"):
        return

    from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy

    # NOTE(review): the original also rebound the local name `pipeline_class` to the legacy class here, which
    # had no effect outside this function; that dead assignment was dropped.
    deprecation_message = (
        "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
        f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
        " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
        " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
        f" checkpoint {pretrained_model_name_or_path} to the format of"
        " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
        # This segment must be an f-string, otherwise the placeholder is printed verbatim.
        f" the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
    )
    deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
822
+
823
+
824
def _update_init_kwargs_with_connected_pipeline(
    init_kwargs: dict, passed_pipe_kwargs: dict, passed_class_objs: dict, folder: str, **pipeline_loading_kwargs
) -> dict:
    """
    Load the pipelines connected to the one in `folder` and add their components to `init_kwargs`.

    The connected repositories are read from the model card at `<folder>/README.md`: for each prefix in
    `CONNECTED_PIPES_KEYS`, the first entry of the attribute of that name on the card data is used, if
    present. Each connected pipeline is loaded with `DiffusionPipeline.from_pretrained`, and its components
    are stored under `<prefix>_<component_name>`, e.g. "prior_text_encoder".

    Args:
        init_kwargs: Dict that is updated in place and returned.
        passed_pipe_kwargs: User-passed pipeline kwargs; entries whose first "_"-separated segment equals a
            prefix are forwarded to that connected pipeline with the leading `<prefix>_` removed.
        passed_class_objs: User-passed components, filtered and forwarded in the same way.
        folder: Local folder containing the `README.md` model card.
        **pipeline_loading_kwargs: Loading options forwarded to every connected pipeline, except those whose
            name contains "scheduler".

    Returns:
        The updated `init_kwargs`.
    """
    from .pipeline_utils import DiffusionPipeline

    modelcard = ModelCard.load(os.path.join(folder, "README.md"))
    connected_repo_ids = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS}

    # We don't pass the scheduler arguments, to match the existing logic:
    # https://github.com/huggingface/diffusers/blob/867e0c919e1aa7ef8b03c8eb1460f4f875a683ae/src/diffusers/pipelines/pipeline_utils.py#L906C13-L925C14
    loading_kwargs = {k: v for k, v in pipeline_loading_kwargs.items() if "scheduler" not in k}

    def get_connected_passed_kwargs(prefix):
        def select(passed):
            # Strip only the leading "<prefix>_"; later occurrences are part of the component name.
            return {k.replace(f"{prefix}_", "", 1): w for k, w in passed.items() if k.split("_")[0] == prefix}

        # Pipeline kwargs take precedence over class objects on key collisions.
        return {**select(passed_class_objs), **select(passed_pipe_kwargs)}

    connected_pipes = {
        prefix: DiffusionPipeline.from_pretrained(repo_id, **loading_kwargs, **get_connected_passed_kwargs(prefix))
        for prefix, repo_id in connected_repo_ids.items()
        if repo_id is not None
    }

    for prefix, connected_pipe in connected_pipes.items():
        # add connected pipes to `init_kwargs` with <prefix>_<component_name>, e.g. "prior_text_encoder"
        init_kwargs.update(
            {"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
        )

    return init_kwargs
866
+
867
+
868
def _get_custom_components_and_folders(
    pretrained_model_name: str,
    config_dict: Dict[str, Any],
    filenames: Optional[List[str]] = None,
    variant_filenames: Optional[List[str]] = None,
    variant: Optional[str] = None,
):
    """
    Collect the component folders of a pipeline and the components that ship their own module file.

    Args:
        pretrained_model_name: Repository name, used in error messages only.
        config_dict: Pipeline config (`model_index.json` content). It is not modified.
        filenames: Repository-relative file paths; `None` is treated as an empty list.
        variant_filenames: File paths belonging to the requested variant; `None` is treated as empty.
        variant: Requested weights variant, if any.

    Returns:
        A `(custom_components, folder_names)` tuple. `folder_names` lists every config key (other than
        `_class_name`) whose value is a list. `custom_components` maps a component name to its module name
        for every component whose `<component>/<module>.py` is present in `filenames`.

    Raises:
        ValueError: If a component references a module that is neither in `filenames`, nor in
            `LOADABLE_CLASSES`, nor an attribute of `diffusers.pipelines`; or if `variant` is given but
            `variant_filenames` is empty.
    """
    config_dict = config_dict.copy()
    # The parameters default to `None`; normalise them so the membership and length checks below are safe.
    filenames = [] if filenames is None else filenames
    variant_filenames = [] if variant_filenames is None else variant_filenames

    # retrieve all folder_names that contain relevant files
    folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]

    diffusers_module = importlib.import_module(__name__.split(".")[0])
    pipelines = diffusers_module.pipelines

    # optionally create a custom component <> custom file mapping
    custom_components = {}
    for component in folder_names:
        module_candidate = config_dict[component][0]

        if module_candidate is None or not isinstance(module_candidate, str):
            continue

        # We compute candidate file path on the Hub. Do not use `os.path.join`.
        candidate_file = f"{component}/{module_candidate}.py"

        if candidate_file in filenames:
            custom_components[component] = module_candidate
        elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
            raise ValueError(
                f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
            )

    if len(variant_filenames) == 0 and variant is not None:
        error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
        raise ValueError(error_message)

    return custom_components, folder_names
906
+
907
+
908
def _get_ignore_patterns(
    passed_components,
    model_folder_names: List[str],
    model_filenames: List[str],
    variant_filenames: List[str],
    use_safetensors: bool,
    from_flax: bool,
    allow_pickle: bool,
    use_onnx: bool,
    is_onnx: bool,
    variant: Optional[str] = None,
) -> List[str]:
    """
    Build the list of glob patterns for files that should be ignored for the requested weight format.

    Args:
        passed_components: Components excluded from the safetensors compatibility check.
        model_folder_names: Component folders considered by the compatibility check.
        model_filenames: All candidate model file paths.
        variant_filenames: File paths that belong to the requested `variant`.
        use_safetensors: Whether safetensors weights are requested.
        from_flax: Whether Flax weights are loaded; if so all PyTorch, safetensors and ONNX files are ignored.
        allow_pickle: Whether falling back to pickled weights is allowed.
        use_onnx: Whether ONNX files are needed; `None` defers to `is_onnx`.
        is_onnx: Fallback for `use_onnx`.
        variant: Requested variant, used in messages only.

    Returns:
        The ignore patterns.

    Raises:
        EnvironmentError: If safetensors are required (`use_safetensors` and not `allow_pickle`) but
            `is_safetensors_compatible` reports that they are not available.
    """
    # Evaluate the compatibility check once; it is only relevant when safetensors were requested.
    load_safetensors = bool(use_safetensors) and is_safetensors_compatible(
        model_filenames, passed_components=passed_components, folder_names=model_folder_names
    )

    if use_safetensors and not allow_pickle and not load_safetensors:
        raise EnvironmentError(
            f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
        )

    if from_flax:
        return ["*.bin", "*.safetensors", "*.onnx", "*.pb"]

    if load_safetensors:
        ignore_patterns = ["*.bin", "*.msgpack"]
        weight_extension = ".safetensors"
    else:
        ignore_patterns = ["*.safetensors", "*.msgpack"]
        weight_extension = ".bin"

    use_onnx = use_onnx if use_onnx is not None else is_onnx
    if not use_onnx:
        ignore_patterns += ["*.onnx", "*.pb"]

    # Warn when variant and non-variant weights of the selected format will be mixed.
    variant_weight_files = {f for f in variant_filenames if f.endswith(weight_extension)}
    model_weight_files = {f for f in model_filenames if f.endswith(weight_extension)}
    if len(variant_weight_files) > 0 and model_weight_files != variant_weight_files:
        logger.warning(
            f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
            f"[{', '.join(variant_weight_files)}]\nLoaded non-{variant} filenames:\n"
            f"[{', '.join(model_weight_files - variant_weight_files)}]\nIf this behavior is not expected, please "
            f"check your folder structure."
        )

    return ignore_patterns