diffusers 0.27.0__py3-none-any.whl → 0.32.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (445) hide show
  1. diffusers/__init__.py +233 -6
  2. diffusers/callbacks.py +209 -0
  3. diffusers/commands/env.py +102 -6
  4. diffusers/configuration_utils.py +45 -16
  5. diffusers/dependency_versions_table.py +4 -3
  6. diffusers/image_processor.py +434 -110
  7. diffusers/loaders/__init__.py +42 -9
  8. diffusers/loaders/ip_adapter.py +626 -36
  9. diffusers/loaders/lora_base.py +900 -0
  10. diffusers/loaders/lora_conversion_utils.py +991 -125
  11. diffusers/loaders/lora_pipeline.py +3812 -0
  12. diffusers/loaders/peft.py +571 -7
  13. diffusers/loaders/single_file.py +405 -173
  14. diffusers/loaders/single_file_model.py +385 -0
  15. diffusers/loaders/single_file_utils.py +1783 -713
  16. diffusers/loaders/textual_inversion.py +41 -23
  17. diffusers/loaders/transformer_flux.py +181 -0
  18. diffusers/loaders/transformer_sd3.py +89 -0
  19. diffusers/loaders/unet.py +464 -540
  20. diffusers/loaders/unet_loader_utils.py +163 -0
  21. diffusers/models/__init__.py +76 -7
  22. diffusers/models/activations.py +65 -10
  23. diffusers/models/adapter.py +53 -53
  24. diffusers/models/attention.py +605 -18
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +4304 -687
  27. diffusers/models/autoencoders/__init__.py +8 -0
  28. diffusers/models/autoencoders/autoencoder_asym_kl.py +15 -17
  29. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  30. diffusers/models/autoencoders/autoencoder_kl.py +110 -28
  31. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  32. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
  33. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  34. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  35. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  36. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +19 -24
  37. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  38. diffusers/models/autoencoders/autoencoder_tiny.py +21 -18
  39. diffusers/models/autoencoders/consistency_decoder_vae.py +45 -20
  40. diffusers/models/autoencoders/vae.py +41 -29
  41. diffusers/models/autoencoders/vq_model.py +182 -0
  42. diffusers/models/controlnet.py +47 -800
  43. diffusers/models/controlnet_flux.py +70 -0
  44. diffusers/models/controlnet_sd3.py +68 -0
  45. diffusers/models/controlnet_sparsectrl.py +116 -0
  46. diffusers/models/controlnets/__init__.py +23 -0
  47. diffusers/models/controlnets/controlnet.py +872 -0
  48. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +9 -9
  49. diffusers/models/controlnets/controlnet_flux.py +536 -0
  50. diffusers/models/controlnets/controlnet_hunyuan.py +401 -0
  51. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  52. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  53. diffusers/models/controlnets/controlnet_union.py +832 -0
  54. diffusers/models/controlnets/controlnet_xs.py +1946 -0
  55. diffusers/models/controlnets/multicontrolnet.py +183 -0
  56. diffusers/models/downsampling.py +85 -18
  57. diffusers/models/embeddings.py +1856 -158
  58. diffusers/models/embeddings_flax.py +23 -9
  59. diffusers/models/model_loading_utils.py +480 -0
  60. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  61. diffusers/models/modeling_flax_utils.py +2 -7
  62. diffusers/models/modeling_outputs.py +14 -0
  63. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  64. diffusers/models/modeling_utils.py +611 -146
  65. diffusers/models/normalization.py +361 -20
  66. diffusers/models/resnet.py +18 -23
  67. diffusers/models/transformers/__init__.py +16 -0
  68. diffusers/models/transformers/auraflow_transformer_2d.py +544 -0
  69. diffusers/models/transformers/cogvideox_transformer_3d.py +542 -0
  70. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  71. diffusers/models/transformers/dual_transformer_2d.py +9 -8
  72. diffusers/models/transformers/hunyuan_transformer_2d.py +578 -0
  73. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  74. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  75. diffusers/models/transformers/pixart_transformer_2d.py +445 -0
  76. diffusers/models/transformers/prior_transformer.py +13 -13
  77. diffusers/models/transformers/sana_transformer.py +488 -0
  78. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  79. diffusers/models/transformers/t5_film_transformer.py +17 -19
  80. diffusers/models/transformers/transformer_2d.py +297 -187
  81. diffusers/models/transformers/transformer_allegro.py +422 -0
  82. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  83. diffusers/models/transformers/transformer_flux.py +593 -0
  84. diffusers/models/transformers/transformer_hunyuan_video.py +791 -0
  85. diffusers/models/transformers/transformer_ltx.py +469 -0
  86. diffusers/models/transformers/transformer_mochi.py +499 -0
  87. diffusers/models/transformers/transformer_sd3.py +461 -0
  88. diffusers/models/transformers/transformer_temporal.py +21 -19
  89. diffusers/models/unets/unet_1d.py +8 -8
  90. diffusers/models/unets/unet_1d_blocks.py +31 -31
  91. diffusers/models/unets/unet_2d.py +17 -10
  92. diffusers/models/unets/unet_2d_blocks.py +225 -149
  93. diffusers/models/unets/unet_2d_condition.py +50 -53
  94. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  95. diffusers/models/unets/unet_3d_blocks.py +192 -1057
  96. diffusers/models/unets/unet_3d_condition.py +22 -27
  97. diffusers/models/unets/unet_i2vgen_xl.py +22 -18
  98. diffusers/models/unets/unet_kandinsky3.py +2 -2
  99. diffusers/models/unets/unet_motion_model.py +1413 -89
  100. diffusers/models/unets/unet_spatio_temporal_condition.py +40 -16
  101. diffusers/models/unets/unet_stable_cascade.py +19 -18
  102. diffusers/models/unets/uvit_2d.py +2 -2
  103. diffusers/models/upsampling.py +95 -26
  104. diffusers/models/vq_model.py +12 -164
  105. diffusers/optimization.py +1 -1
  106. diffusers/pipelines/__init__.py +202 -3
  107. diffusers/pipelines/allegro/__init__.py +48 -0
  108. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  109. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  110. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  111. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  112. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  113. diffusers/pipelines/animatediff/__init__.py +8 -0
  114. diffusers/pipelines/animatediff/pipeline_animatediff.py +122 -109
  115. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1106 -0
  116. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1288 -0
  117. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1010 -0
  118. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +236 -180
  119. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  120. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  121. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  122. diffusers/pipelines/audioldm2/modeling_audioldm2.py +58 -39
  123. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +121 -36
  124. diffusers/pipelines/aura_flow/__init__.py +48 -0
  125. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +584 -0
  126. diffusers/pipelines/auto_pipeline.py +196 -28
  127. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  128. diffusers/pipelines/blip_diffusion/modeling_blip2.py +6 -6
  129. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  130. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  131. diffusers/pipelines/cogvideo/__init__.py +54 -0
  132. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +772 -0
  133. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  134. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +885 -0
  135. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +851 -0
  136. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  137. diffusers/pipelines/cogview3/__init__.py +47 -0
  138. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  139. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  140. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +6 -6
  141. diffusers/pipelines/controlnet/__init__.py +86 -80
  142. diffusers/pipelines/controlnet/multicontrolnet.py +7 -182
  143. diffusers/pipelines/controlnet/pipeline_controlnet.py +134 -87
  144. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  145. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +93 -77
  146. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +88 -197
  147. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +136 -90
  148. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +176 -80
  149. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +125 -89
  150. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  151. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  152. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  153. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  154. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  155. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1060 -0
  156. diffusers/pipelines/controlnet_sd3/__init__.py +57 -0
  157. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1133 -0
  158. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  159. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  160. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +916 -0
  161. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1111 -0
  162. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  163. diffusers/pipelines/deepfloyd_if/pipeline_if.py +16 -30
  164. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +20 -35
  165. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +23 -41
  166. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +22 -38
  167. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +25 -41
  168. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +19 -34
  169. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  170. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  171. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  172. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +70 -30
  173. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +48 -25
  174. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  175. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  176. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +21 -20
  177. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +27 -29
  178. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +33 -27
  179. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +33 -23
  180. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +36 -30
  181. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +102 -69
  182. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  183. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  184. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  185. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  186. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  187. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  188. diffusers/pipelines/flux/__init__.py +69 -0
  189. diffusers/pipelines/flux/modeling_flux.py +47 -0
  190. diffusers/pipelines/flux/pipeline_flux.py +957 -0
  191. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  192. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  193. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  194. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  195. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  196. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  197. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  198. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  199. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  200. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  201. diffusers/pipelines/flux/pipeline_output.py +37 -0
  202. diffusers/pipelines/free_init_utils.py +41 -38
  203. diffusers/pipelines/free_noise_utils.py +596 -0
  204. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  205. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  206. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  207. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  208. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +916 -0
  209. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  210. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  211. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +32 -29
  212. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  213. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  214. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  215. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  216. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +34 -31
  217. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  218. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  219. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  220. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  221. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  222. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  223. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  224. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +22 -35
  225. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +26 -37
  226. diffusers/pipelines/kolors/__init__.py +54 -0
  227. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  228. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1250 -0
  229. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  230. diffusers/pipelines/kolors/text_encoder.py +889 -0
  231. diffusers/pipelines/kolors/tokenizer.py +338 -0
  232. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +82 -62
  233. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +77 -60
  234. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +12 -12
  235. diffusers/pipelines/latte/__init__.py +48 -0
  236. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  237. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +80 -74
  238. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +85 -76
  239. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  240. diffusers/pipelines/ltx/__init__.py +50 -0
  241. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  242. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  243. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  244. diffusers/pipelines/lumina/__init__.py +48 -0
  245. diffusers/pipelines/lumina/pipeline_lumina.py +890 -0
  246. diffusers/pipelines/marigold/__init__.py +50 -0
  247. diffusers/pipelines/marigold/marigold_image_processing.py +576 -0
  248. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  249. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  250. diffusers/pipelines/mochi/__init__.py +48 -0
  251. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  252. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  253. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  254. diffusers/pipelines/pag/__init__.py +80 -0
  255. diffusers/pipelines/pag/pag_utils.py +243 -0
  256. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1328 -0
  257. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  258. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1610 -0
  259. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  260. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +969 -0
  261. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  262. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +865 -0
  263. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  264. diffusers/pipelines/pag/pipeline_pag_sd.py +1062 -0
  265. diffusers/pipelines/pag/pipeline_pag_sd_3.py +994 -0
  266. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  267. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +866 -0
  268. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  269. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  270. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1345 -0
  271. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1544 -0
  272. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1776 -0
  273. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  274. diffusers/pipelines/pia/pipeline_pia.py +74 -164
  275. diffusers/pipelines/pipeline_flax_utils.py +5 -10
  276. diffusers/pipelines/pipeline_loading_utils.py +515 -53
  277. diffusers/pipelines/pipeline_utils.py +411 -222
  278. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  279. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +76 -93
  280. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +873 -0
  281. diffusers/pipelines/sana/__init__.py +47 -0
  282. diffusers/pipelines/sana/pipeline_output.py +21 -0
  283. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  284. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +27 -23
  285. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  286. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  287. diffusers/pipelines/shap_e/renderer.py +1 -1
  288. diffusers/pipelines/stable_audio/__init__.py +50 -0
  289. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  290. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +756 -0
  291. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +71 -25
  292. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  293. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +35 -34
  294. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  295. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +20 -11
  296. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  297. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  298. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  299. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +145 -79
  300. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +43 -28
  301. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  302. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +100 -68
  303. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +109 -201
  304. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +131 -32
  305. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +247 -87
  306. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +30 -29
  307. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +35 -27
  308. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +49 -42
  309. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  310. diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
  311. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  312. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
  313. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1036 -0
  314. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1250 -0
  315. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +29 -20
  316. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +59 -58
  317. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +31 -25
  318. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +38 -22
  319. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -24
  320. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -23
  321. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +107 -67
  322. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +316 -69
  323. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  324. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  325. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +98 -30
  326. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +121 -83
  327. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +161 -105
  328. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +142 -218
  329. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -29
  330. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  331. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  332. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +69 -39
  333. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +105 -74
  334. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  335. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +29 -49
  336. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +32 -93
  337. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +37 -25
  338. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +54 -40
  339. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  340. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  341. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  342. diffusers/pipelines/unidiffuser/modeling_uvit.py +12 -12
  343. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +29 -28
  344. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  345. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  346. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +6 -8
  347. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  348. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  349. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +15 -14
  350. diffusers/{models/dual_transformer_2d.py → quantizers/__init__.py} +2 -6
  351. diffusers/quantizers/auto.py +139 -0
  352. diffusers/quantizers/base.py +233 -0
  353. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  354. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  355. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  356. diffusers/quantizers/gguf/__init__.py +1 -0
  357. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  358. diffusers/quantizers/gguf/utils.py +456 -0
  359. diffusers/quantizers/quantization_config.py +669 -0
  360. diffusers/quantizers/torchao/__init__.py +15 -0
  361. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  362. diffusers/schedulers/__init__.py +12 -2
  363. diffusers/schedulers/deprecated/__init__.py +1 -1
  364. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  365. diffusers/schedulers/scheduling_amused.py +5 -5
  366. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  367. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  368. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  369. diffusers/schedulers/scheduling_ddim.py +27 -26
  370. diffusers/schedulers/scheduling_ddim_cogvideox.py +452 -0
  371. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  372. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  373. diffusers/schedulers/scheduling_ddim_parallel.py +32 -31
  374. diffusers/schedulers/scheduling_ddpm.py +27 -30
  375. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  376. diffusers/schedulers/scheduling_ddpm_parallel.py +33 -36
  377. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  378. diffusers/schedulers/scheduling_deis_multistep.py +150 -50
  379. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  380. diffusers/schedulers/scheduling_dpmsolver_multistep.py +221 -84
  381. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  382. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +158 -52
  383. diffusers/schedulers/scheduling_dpmsolver_sde.py +153 -34
  384. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +275 -86
  385. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +81 -57
  386. diffusers/schedulers/scheduling_edm_euler.py +62 -39
  387. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +30 -29
  388. diffusers/schedulers/scheduling_euler_discrete.py +255 -74
  389. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +458 -0
  390. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +320 -0
  391. diffusers/schedulers/scheduling_heun_discrete.py +174 -46
  392. diffusers/schedulers/scheduling_ipndm.py +9 -9
  393. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +138 -29
  394. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +132 -26
  395. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  396. diffusers/schedulers/scheduling_lcm.py +23 -29
  397. diffusers/schedulers/scheduling_lms_discrete.py +105 -28
  398. diffusers/schedulers/scheduling_pndm.py +20 -20
  399. diffusers/schedulers/scheduling_repaint.py +21 -21
  400. diffusers/schedulers/scheduling_sasolver.py +157 -60
  401. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  402. diffusers/schedulers/scheduling_tcd.py +41 -36
  403. diffusers/schedulers/scheduling_unclip.py +19 -16
  404. diffusers/schedulers/scheduling_unipc_multistep.py +243 -47
  405. diffusers/schedulers/scheduling_utils.py +12 -5
  406. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  407. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  408. diffusers/training_utils.py +214 -30
  409. diffusers/utils/__init__.py +17 -1
  410. diffusers/utils/constants.py +3 -0
  411. diffusers/utils/doc_utils.py +1 -0
  412. diffusers/utils/dummy_pt_objects.py +592 -7
  413. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  414. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  415. diffusers/utils/dummy_torch_and_transformers_objects.py +1001 -71
  416. diffusers/utils/dynamic_modules_utils.py +34 -29
  417. diffusers/utils/export_utils.py +50 -6
  418. diffusers/utils/hub_utils.py +131 -17
  419. diffusers/utils/import_utils.py +210 -8
  420. diffusers/utils/loading_utils.py +118 -5
  421. diffusers/utils/logging.py +4 -2
  422. diffusers/utils/peft_utils.py +37 -7
  423. diffusers/utils/state_dict_utils.py +13 -2
  424. diffusers/utils/testing_utils.py +193 -11
  425. diffusers/utils/torch_utils.py +4 -0
  426. diffusers/video_processor.py +113 -0
  427. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/METADATA +82 -91
  428. diffusers-0.32.2.dist-info/RECORD +550 -0
  429. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/WHEEL +1 -1
  430. diffusers/loaders/autoencoder.py +0 -146
  431. diffusers/loaders/controlnet.py +0 -136
  432. diffusers/loaders/lora.py +0 -1349
  433. diffusers/models/prior_transformer.py +0 -12
  434. diffusers/models/t5_film_transformer.py +0 -70
  435. diffusers/models/transformer_2d.py +0 -25
  436. diffusers/models/transformer_temporal.py +0 -34
  437. diffusers/models/unet_1d.py +0 -26
  438. diffusers/models/unet_1d_blocks.py +0 -203
  439. diffusers/models/unet_2d.py +0 -27
  440. diffusers/models/unet_2d_blocks.py +0 -375
  441. diffusers/models/unet_2d_condition.py +0 -25
  442. diffusers-0.27.0.dist-info/RECORD +0 -399
  443. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/LICENSE +0 -0
  444. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/entry_points.txt +0 -0
  445. {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/top_level.txt +0 -0
@@ -13,6 +13,7 @@
13
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
  # See the License for the specific language governing permissions and
15
15
  # limitations under the License.
16
+ import enum
16
17
  import fnmatch
17
18
  import importlib
18
19
  import inspect
@@ -21,7 +22,7 @@ import re
21
22
  import sys
22
23
  from dataclasses import dataclass
23
24
  from pathlib import Path
24
- from typing import Any, Callable, Dict, List, Optional, Union
25
+ from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin
25
26
 
26
27
  import numpy as np
27
28
  import PIL.Image
@@ -43,38 +44,45 @@ from .. import __version__
43
44
  from ..configuration_utils import ConfigMixin
44
45
  from ..models import AutoencoderKL
45
46
  from ..models.attention_processor import FusedAttnProcessor2_0
46
- from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
47
+ from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
48
+ from ..quantizers.bitsandbytes.utils import _check_bnb_status
47
49
  from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
48
50
  from ..utils import (
49
51
  CONFIG_NAME,
50
52
  DEPRECATED_REVISION_ARGS,
51
53
  BaseOutput,
52
54
  PushToHubMixin,
53
- deprecate,
54
55
  is_accelerate_available,
55
56
  is_accelerate_version,
56
57
  is_torch_npu_available,
57
58
  is_torch_version,
59
+ is_transformers_version,
58
60
  logging,
59
61
  numpy_to_pil,
60
62
  )
61
- from ..utils.hub_utils import load_or_create_model_card, populate_model_card
63
+ from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
62
64
  from ..utils.torch_utils import is_compiled_module
63
65
 
64
66
 
65
67
  if is_torch_npu_available():
66
68
  import torch_npu # noqa: F401
67
69
 
68
-
69
70
  from .pipeline_loading_utils import (
70
71
  ALL_IMPORTABLE_CLASSES,
71
72
  CONNECTED_PIPES_KEYS,
72
73
  CUSTOM_PIPELINE_FILE_NAME,
73
74
  LOADABLE_CLASSES,
74
75
  _fetch_class_library_tuple,
76
+ _get_custom_components_and_folders,
77
+ _get_custom_pipeline_class,
78
+ _get_final_device_map,
79
+ _get_ignore_patterns,
75
80
  _get_pipeline_class,
81
+ _identify_model_variants,
82
+ _maybe_raise_warning_for_inpainting,
83
+ _resolve_custom_pipeline_and_cls,
76
84
  _unwrap_model,
77
- is_safetensors_compatible,
85
+ _update_init_kwargs_with_connected_pipeline,
78
86
  load_sub_model,
79
87
  maybe_raise_or_warn,
80
88
  variant_compatible_siblings,
@@ -90,6 +98,8 @@ LIBRARIES = []
90
98
  for library in LOADABLE_CLASSES:
91
99
  LIBRARIES.append(library)
92
100
 
101
+ SUPPORTED_DEVICE_MAP = ["balanced"]
102
+
93
103
  logger = logging.get_logger(__name__)
94
104
 
95
105
 
@@ -140,6 +150,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
140
150
 
141
151
  config_name = "model_index.json"
142
152
  model_cpu_offload_seq = None
153
+ hf_device_map = None
143
154
  _optional_components = []
144
155
  _exclude_from_cpu_offload = []
145
156
  _load_connected_pipes = False
@@ -180,6 +191,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
180
191
  save_directory: Union[str, os.PathLike],
181
192
  safe_serialization: bool = True,
182
193
  variant: Optional[str] = None,
194
+ max_shard_size: Optional[Union[int, str]] = None,
183
195
  push_to_hub: bool = False,
184
196
  **kwargs,
185
197
  ):
@@ -195,6 +207,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
195
207
  Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
196
208
  variant (`str`, *optional*):
197
209
  If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
210
+ max_shard_size (`int` or `str`, defaults to `None`):
211
+ The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
212
+ lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
213
+ If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
214
+ period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`.
215
+ This is to establish a common default size for this argument across different libraries in the Hugging
216
+ Face ecosystem (`transformers`, and `accelerate`, for example).
198
217
  push_to_hub (`bool`, *optional*, defaults to `False`):
199
218
  Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
200
219
  repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
@@ -210,7 +229,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
210
229
 
211
230
  if push_to_hub:
212
231
  commit_message = kwargs.pop("commit_message", None)
213
- private = kwargs.pop("private", False)
232
+ private = kwargs.pop("private", None)
214
233
  create_pr = kwargs.pop("create_pr", False)
215
234
  token = kwargs.pop("token", None)
216
235
  repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
@@ -269,12 +288,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
269
288
  save_method_signature = inspect.signature(save_method)
270
289
  save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
271
290
  save_method_accept_variant = "variant" in save_method_signature.parameters
291
+ save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
272
292
 
273
293
  save_kwargs = {}
274
294
  if save_method_accept_safe:
275
295
  save_kwargs["safe_serialization"] = safe_serialization
276
296
  if save_method_accept_variant:
277
297
  save_kwargs["variant"] = variant
298
+ if save_method_accept_max_shard_size and max_shard_size is not None:
299
+ # max_shard_size is expected to not be None in ModelMixin
300
+ save_kwargs["max_shard_size"] = max_shard_size
278
301
 
279
302
  save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
280
303
 
@@ -365,14 +388,17 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
365
388
  )
366
389
 
367
390
  device = device or device_arg
391
+ pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
368
392
 
369
393
  # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
370
394
  def module_is_sequentially_offloaded(module):
371
395
  if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
372
396
  return False
373
397
 
374
- return hasattr(module, "_hf_hook") and not isinstance(
375
- module._hf_hook, (accelerate.hooks.CpuOffload, accelerate.hooks.AlignDevicesHook)
398
+ return hasattr(module, "_hf_hook") and (
399
+ isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
400
+ or hasattr(module._hf_hook, "hooks")
401
+ and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook)
376
402
  )
377
403
 
378
404
  def module_is_offloaded(module):
@@ -385,9 +411,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
385
411
  pipeline_is_sequentially_offloaded = any(
386
412
  module_is_sequentially_offloaded(module) for _, module in self.components.items()
387
413
  )
388
- if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":
414
+ if device and torch.device(device).type == "cuda":
415
+ if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
416
+ raise ValueError(
417
+ "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
418
+ )
419
+ # PR: https://github.com/huggingface/accelerate/pull/3223/
420
+ elif pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
421
+ raise ValueError(
422
+ "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
423
+ )
424
+
425
+ is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
426
+ if is_pipeline_device_mapped:
389
427
  raise ValueError(
390
- "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
428
+ "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`."
391
429
  )
392
430
 
393
431
  # Display a warning in this case (the operation succeeds but the benefits are lost)
@@ -403,18 +441,23 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
403
441
 
404
442
  is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
405
443
  for module in modules:
406
- is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit
444
+ _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module)
407
445
 
408
- if is_loaded_in_8bit and dtype is not None:
446
+ if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None:
409
447
  logger.warning(
410
- f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision."
448
+ f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision."
411
449
  )
412
450
 
413
- if is_loaded_in_8bit and device is not None:
451
+ if is_loaded_in_8bit_bnb and device is not None:
414
452
  logger.warning(
415
- f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}."
453
+ f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
416
454
  )
417
- else:
455
+
456
+ # This can happen for `transformer` models. CPU placement was added in
457
+ # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
458
+ if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
459
+ module.to(device=device)
460
+ elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
418
461
  module.to(device, dtype)
419
462
 
420
463
  if (
@@ -520,9 +563,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
520
563
  cache_dir (`Union[str, os.PathLike]`, *optional*):
521
564
  Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
522
565
  is not used.
523
- resume_download (`bool`, *optional*, defaults to `False`):
524
- Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
525
- incompletely downloaded files are deleted.
566
+
526
567
  proxies (`Dict[str, str]`, *optional*):
527
568
  A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
528
569
  'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -539,7 +580,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
539
580
  allowed by Git.
540
581
  custom_revision (`str`, *optional*):
541
582
  The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
542
- `revision` when loading a custom pipeline from the Hub. Defaults to the latest stable 🤗 Diffusers version.
583
+ `revision` when loading a custom pipeline from the Hub. Defaults to the latest stable 🤗 Diffusers
584
+ version.
543
585
  mirror (`str`, *optional*):
544
586
  Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
545
587
  guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
@@ -610,8 +652,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
610
652
  >>> pipeline.scheduler = scheduler
611
653
  ```
612
654
  """
655
+ # Copy the kwargs to re-use during loading connected pipeline.
656
+ kwargs_copied = kwargs.copy()
657
+
613
658
  cache_dir = kwargs.pop("cache_dir", None)
614
- resume_download = kwargs.pop("resume_download", False)
615
659
  force_download = kwargs.pop("force_download", False)
616
660
  proxies = kwargs.pop("proxies", None)
617
661
  local_files_only = kwargs.pop("local_files_only", None)
@@ -642,18 +686,35 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
642
686
  " install accelerate\n```\n."
643
687
  )
644
688
 
689
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
690
+ raise NotImplementedError(
691
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
692
+ " `low_cpu_mem_usage=False`."
693
+ )
694
+
645
695
  if device_map is not None and not is_torch_version(">=", "1.9.0"):
646
696
  raise NotImplementedError(
647
697
  "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
648
698
  " `device_map=None`."
649
699
  )
650
700
 
651
- if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
701
+ if device_map is not None and not is_accelerate_available():
652
702
  raise NotImplementedError(
653
- "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
654
- " `low_cpu_mem_usage=False`."
703
+ "Using `device_map` requires the `accelerate` library. Please install it using: `pip install accelerate`."
655
704
  )
656
705
 
706
+ if device_map is not None and not isinstance(device_map, str):
707
+ raise ValueError("`device_map` must be a string.")
708
+
709
+ if device_map is not None and device_map not in SUPPORTED_DEVICE_MAP:
710
+ raise NotImplementedError(
711
+ f"{device_map} not supported. Supported strategies are: {', '.join(SUPPORTED_DEVICE_MAP)}"
712
+ )
713
+
714
+ if device_map is not None and device_map in SUPPORTED_DEVICE_MAP:
715
+ if is_accelerate_version("<", "0.28.0"):
716
+ raise NotImplementedError("Device placement requires `accelerate` version `0.28.0` or later.")
717
+
657
718
  if low_cpu_mem_usage is False and device_map is not None:
658
719
  raise ValueError(
659
720
  f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
@@ -671,7 +732,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
671
732
  cached_folder = cls.download(
672
733
  pretrained_model_name_or_path,
673
734
  cache_dir=cache_dir,
674
- resume_download=resume_download,
675
735
  force_download=force_download,
676
736
  proxies=proxies,
677
737
  local_files_only=local_files_only,
@@ -689,39 +749,43 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
689
749
  else:
690
750
  cached_folder = pretrained_model_name_or_path
691
751
 
752
+ # The variant filenames can have the legacy sharding checkpoint format that we check and throw
753
+ # a warning if detected.
754
+ if variant is not None and _check_legacy_sharding_variant_format(folder=cached_folder, variant=variant):
755
+ warn_msg = (
756
+ f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
757
+ "Please check your files carefully:\n\n"
758
+ "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
759
+ "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
760
+ "If you find any files in the deprecated format:\n"
761
+ "1. Remove all existing checkpoint files for this variant.\n"
762
+ "2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
763
+ "This will ensure you're using the most up-to-date and compatible checkpoint format."
764
+ )
765
+ logger.warning(warn_msg)
766
+
692
767
  config_dict = cls.load_config(cached_folder)
693
768
 
694
769
  # pop out "_ignore_files" as it is only needed for download
695
770
  config_dict.pop("_ignore_files", None)
696
771
 
697
772
  # 2. Define which model components should load variants
698
- # We retrieve the information by matching whether variant
699
- # model checkpoints exist in the subfolders
700
- model_variants = {}
701
- if variant is not None:
702
- for folder in os.listdir(cached_folder):
703
- folder_path = os.path.join(cached_folder, folder)
704
- is_folder = os.path.isdir(folder_path) and folder in config_dict
705
- variant_exists = is_folder and any(
706
- p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)
707
- )
708
- if variant_exists:
709
- model_variants[folder] = variant
773
+ # We retrieve the information by matching whether variant model checkpoints exist in the subfolders.
774
+ # Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors`
775
+ # with variant being `"fp16"`.
776
+ model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict)
777
+ if len(model_variants) == 0 and variant is not None:
778
+ error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
779
+ raise ValueError(error_message)
710
780
 
711
781
  # 3. Load the pipeline class, if using custom module then load it from the hub
712
782
  # if we load from explicit class, let's use it
713
- custom_class_name = None
714
- if os.path.isfile(os.path.join(cached_folder, f"{custom_pipeline}.py")):
715
- custom_pipeline = os.path.join(cached_folder, f"{custom_pipeline}.py")
716
- elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile(
717
- os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
718
- ):
719
- custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
720
- custom_class_name = config_dict["_class_name"][1]
721
-
783
+ custom_pipeline, custom_class_name = _resolve_custom_pipeline_and_cls(
784
+ folder=cached_folder, config=config_dict, custom_pipeline=custom_pipeline
785
+ )
722
786
  pipeline_class = _get_pipeline_class(
723
787
  cls,
724
- config_dict,
788
+ config=config_dict,
725
789
  load_connected_pipeline=load_connected_pipeline,
726
790
  custom_pipeline=custom_pipeline,
727
791
  class_name=custom_class_name,
@@ -729,24 +793,17 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
729
793
  revision=custom_revision,
730
794
  )
731
795
 
796
+ if device_map is not None and pipeline_class._load_connected_pipes:
797
+ raise NotImplementedError("`device_map` is not yet supported for connected pipelines.")
798
+
732
799
  # DEPRECATED: To be removed in 1.0.0
733
- if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
734
- version.parse(config_dict["_diffusers_version"]).base_version
735
- ) <= version.parse("0.5.1"):
736
- from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy
737
-
738
- pipeline_class = StableDiffusionInpaintPipelineLegacy
739
-
740
- deprecation_message = (
741
- "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
742
- f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
743
- " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
744
- " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
745
- f" checkpoint {pretrained_model_name_or_path} to the format of"
746
- " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
747
- " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
748
- )
749
- deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
800
+ # we are deprecating the `StableDiffusionInpaintPipelineLegacy` pipeline which gets loaded
801
+ # when a user requests for a `StableDiffusionInpaintPipeline` with `diffusers` version being <= 0.5.1.
802
+ _maybe_raise_warning_for_inpainting(
803
+ pipeline_class=pipeline_class,
804
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
805
+ config=config_dict,
806
+ )
750
807
 
751
808
  # 4. Define expected modules given pipeline signature
752
809
  # and define non-None initialized modules (=`init_kwargs`)
@@ -755,9 +812,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
755
812
  # in this case they are already instantiated in `kwargs`
756
813
  # extract them here
757
814
  expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
815
+ expected_types = pipeline_class._get_signature_types()
758
816
  passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
759
817
  passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
760
-
761
818
  init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
762
819
 
763
820
  # define init kwargs and make sure that optional component modules are filtered out
@@ -778,6 +835,26 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
778
835
 
779
836
  init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
780
837
 
838
+ for key in init_dict.keys():
839
+ if key not in passed_class_obj:
840
+ continue
841
+ if "scheduler" in key:
842
+ continue
843
+
844
+ class_obj = passed_class_obj[key]
845
+ _expected_class_types = []
846
+ for expected_type in expected_types[key]:
847
+ if isinstance(expected_type, enum.EnumMeta):
848
+ _expected_class_types.extend(expected_type.__members__.keys())
849
+ else:
850
+ _expected_class_types.append(expected_type.__name__)
851
+
852
+ _is_valid_type = class_obj.__class__.__name__ in _expected_class_types
853
+ if not _is_valid_type:
854
+ logger.warning(
855
+ f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}."
856
+ )
857
+
781
858
  # Special case: safety_checker must be loaded separately when using `from_flax`
782
859
  if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
783
860
  raise NotImplementedError(
@@ -795,17 +872,45 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
795
872
  # import it here to avoid circular import
796
873
  from diffusers import pipelines
797
874
 
798
- # 6. Load each module in the pipeline
875
+ # 6. device map delegation
876
+ final_device_map = None
877
+ if device_map is not None:
878
+ final_device_map = _get_final_device_map(
879
+ device_map=device_map,
880
+ pipeline_class=pipeline_class,
881
+ passed_class_obj=passed_class_obj,
882
+ init_dict=init_dict,
883
+ library=library,
884
+ max_memory=max_memory,
885
+ torch_dtype=torch_dtype,
886
+ cached_folder=cached_folder,
887
+ force_download=force_download,
888
+ proxies=proxies,
889
+ local_files_only=local_files_only,
890
+ token=token,
891
+ revision=revision,
892
+ )
893
+
894
+ # 7. Load each module in the pipeline
895
+ current_device_map = None
799
896
  for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
800
- # 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
897
+ # 7.1 device_map shenanigans
898
+ if final_device_map is not None and len(final_device_map) > 0:
899
+ component_device = final_device_map.get(name, None)
900
+ if component_device is not None:
901
+ current_device_map = {"": component_device}
902
+ else:
903
+ current_device_map = None
904
+
905
+ # 7.2 - now that JAX/Flax is an official framework of the library, we might load from Flax names
801
906
  class_name = class_name[4:] if class_name.startswith("Flax") else class_name
802
907
 
803
- # 6.2 Define all importable classes
908
+ # 7.3 Define all importable classes
804
909
  is_pipeline_module = hasattr(pipelines, library_name)
805
910
  importable_classes = ALL_IMPORTABLE_CLASSES
806
911
  loaded_sub_model = None
807
912
 
808
- # 6.3 Use passed sub model or load class_name from library_name
913
+ # 7.4 Use passed sub model or load class_name from library_name
809
914
  if name in passed_class_obj:
810
915
  # if the model is in a pipeline module, then we load it from the pipeline
811
916
  # check that passed_class_obj has correct parent class
@@ -826,7 +931,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
826
931
  torch_dtype=torch_dtype,
827
932
  provider=provider,
828
933
  sess_options=sess_options,
829
- device_map=device_map,
934
+ device_map=current_device_map,
830
935
  max_memory=max_memory,
831
936
  offload_folder=offload_folder,
832
937
  offload_state_dict=offload_state_dict,
@@ -836,6 +941,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
836
941
  variant=variant,
837
942
  low_cpu_mem_usage=low_cpu_mem_usage,
838
943
  cached_folder=cached_folder,
944
+ use_safetensors=use_safetensors,
839
945
  )
840
946
  logger.info(
841
947
  f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
@@ -843,57 +949,17 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
843
949
 
844
950
  init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
845
951
 
952
+ # 8. Handle connected pipelines.
846
953
  if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")):
847
- modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
848
- connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS}
849
- load_kwargs = {
850
- "cache_dir": cache_dir,
851
- "resume_download": resume_download,
852
- "force_download": force_download,
853
- "proxies": proxies,
854
- "local_files_only": local_files_only,
855
- "token": token,
856
- "revision": revision,
857
- "torch_dtype": torch_dtype,
858
- "custom_pipeline": custom_pipeline,
859
- "custom_revision": custom_revision,
860
- "provider": provider,
861
- "sess_options": sess_options,
862
- "device_map": device_map,
863
- "max_memory": max_memory,
864
- "offload_folder": offload_folder,
865
- "offload_state_dict": offload_state_dict,
866
- "low_cpu_mem_usage": low_cpu_mem_usage,
867
- "variant": variant,
868
- "use_safetensors": use_safetensors,
869
- }
870
-
871
- def get_connected_passed_kwargs(prefix):
872
- connected_passed_class_obj = {
873
- k.replace(f"{prefix}_", ""): w for k, w in passed_class_obj.items() if k.split("_")[0] == prefix
874
- }
875
- connected_passed_pipe_kwargs = {
876
- k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix
877
- }
878
-
879
- connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
880
- return connected_passed_kwargs
881
-
882
- connected_pipes = {
883
- prefix: DiffusionPipeline.from_pretrained(
884
- repo_id, **load_kwargs.copy(), **get_connected_passed_kwargs(prefix)
885
- )
886
- for prefix, repo_id in connected_pipes.items()
887
- if repo_id is not None
888
- }
889
-
890
- for prefix, connected_pipe in connected_pipes.items():
891
- # add connected pipes to `init_kwargs` with <prefix>_<component_name>, e.g. "prior_text_encoder"
892
- init_kwargs.update(
893
- {"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
894
- )
954
+ init_kwargs = _update_init_kwargs_with_connected_pipeline(
955
+ init_kwargs=init_kwargs,
956
+ passed_pipe_kwargs=passed_pipe_kwargs,
957
+ passed_class_objs=passed_class_obj,
958
+ folder=cached_folder,
959
+ **kwargs_copied,
960
+ )
895
961
 
896
- # 7. Potentially add passed objects if expected
962
+ # 9. Potentially add passed objects if expected
897
963
  missing_modules = set(expected_modules) - set(init_kwargs.keys())
898
964
  passed_modules = list(passed_class_obj.keys())
899
965
  optional_modules = pipeline_class._optional_components
@@ -906,11 +972,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
906
972
  f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
907
973
  )
908
974
 
909
- # 8. Instantiate the pipeline
975
+ # 10. Instantiate the pipeline
910
976
  model = pipeline_class(**init_kwargs)
911
977
 
912
- # 9. Save where the model was instantiated from
978
+ # 11. Save where the model was instantiated from
913
979
  model.register_to_config(_name_or_path=pretrained_model_name_or_path)
980
+ if device_map is not None:
981
+ setattr(model, "hf_device_map", final_device_map)
914
982
  return model
915
983
 
916
984
  @property
@@ -939,6 +1007,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
939
1007
  return torch.device(module._hf_hook.execution_device)
940
1008
  return self.device
941
1009
 
1010
+ def remove_all_hooks(self):
1011
+ r"""
1012
+ Removes all hooks that were added when using `enable_sequential_cpu_offload` or `enable_model_cpu_offload`.
1013
+ """
1014
+ for _, model in self.components.items():
1015
+ if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
1016
+ accelerate.hooks.remove_hook_from_module(model, recurse=True)
1017
+ self._all_hooks = []
1018
+
942
1019
  def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
943
1020
  r"""
944
1021
  Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
@@ -953,6 +1030,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
953
1030
  The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
954
1031
  default to "cuda".
955
1032
  """
1033
+ is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
1034
+ if is_pipeline_device_mapped:
1035
+ raise ValueError(
1036
+ "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
1037
+ )
1038
+
956
1039
  if self.model_cpu_offload_seq is None:
957
1040
  raise ValueError(
958
1041
  "Model CPU offload cannot be enabled because no `model_cpu_offload_seq` class attribute is set."
@@ -963,6 +1046,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
963
1046
  else:
964
1047
  raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
965
1048
 
1049
+ self.remove_all_hooks()
1050
+
966
1051
  torch_device = torch.device(device)
967
1052
  device_index = torch_device.index
968
1053
 
@@ -979,11 +1064,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
979
1064
  device = torch.device(f"{device_type}:{self._offload_gpu_id}")
980
1065
  self._offload_device = device
981
1066
 
982
- if self.device.type != "cpu":
983
- self.to("cpu", silence_dtype_warnings=True)
984
- device_mod = getattr(torch, self.device.type, None)
985
- if hasattr(device_mod, "empty_cache") and device_mod.is_available():
986
- device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
1067
+ self.to("cpu", silence_dtype_warnings=True)
1068
+ device_mod = getattr(torch, device.type, None)
1069
+ if hasattr(device_mod, "empty_cache") and device_mod.is_available():
1070
+ device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
987
1071
 
988
1072
  all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
989
1073
 
@@ -991,9 +1075,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
991
1075
  hook = None
992
1076
  for model_str in self.model_cpu_offload_seq.split("->"):
993
1077
  model = all_model_components.pop(model_str, None)
1078
+
994
1079
  if not isinstance(model, torch.nn.Module):
995
1080
  continue
996
1081
 
1082
+ # This is because the model would already be placed on a CUDA device.
1083
+ _, _, is_loaded_in_8bit_bnb = _check_bnb_status(model)
1084
+ if is_loaded_in_8bit_bnb:
1085
+ logger.info(
1086
+ f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit."
1087
+ )
1088
+ continue
1089
+
997
1090
  _, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook)
998
1091
  self._all_hooks.append(hook)
999
1092
 
@@ -1021,11 +1114,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1021
1114
  # `enable_model_cpu_offload` has not be called, so silently do nothing
1022
1115
  return
1023
1116
 
1024
- for hook in self._all_hooks:
1025
- # offload model and remove hook from model
1026
- hook.offload()
1027
- hook.remove()
1028
-
1029
1117
  # make sure the model is in the same state as before calling it
1030
1118
  self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
1031
1119
 
@@ -1048,6 +1136,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1048
1136
  from accelerate import cpu_offload
1049
1137
  else:
1050
1138
  raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
1139
+ self.remove_all_hooks()
1140
+
1141
+ is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
1142
+ if is_pipeline_device_mapped:
1143
+ raise ValueError(
1144
+ "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
1145
+ )
1051
1146
 
1052
1147
  torch_device = torch.device(device)
1053
1148
  device_index = torch_device.index
@@ -1083,6 +1178,19 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1083
1178
  offload_buffers = len(model._parameters) > 0
1084
1179
  cpu_offload(model, device, offload_buffers=offload_buffers)
1085
1180
 
1181
+ def reset_device_map(self):
1182
+ r"""
1183
+ Resets the device maps (if any) to None.
1184
+ """
1185
+ if self.hf_device_map is None:
1186
+ return
1187
+ else:
1188
+ self.remove_all_hooks()
1189
+ for name, component in self.components.items():
1190
+ if isinstance(component, torch.nn.Module):
1191
+ component.to("cpu")
1192
+ self.hf_device_map = None
1193
+
1086
1194
  @classmethod
1087
1195
  @validate_hf_hub_args
1088
1196
  def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
@@ -1121,9 +1229,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1121
1229
  force_download (`bool`, *optional*, defaults to `False`):
1122
1230
  Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1123
1231
  cached versions if they exist.
1124
- resume_download (`bool`, *optional*, defaults to `False`):
1125
- Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
1126
- incompletely downloaded files are deleted.
1232
+
1127
1233
  proxies (`Dict[str, str]`, *optional*):
1128
1234
  A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
1129
1235
  'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -1176,7 +1282,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1176
1282
 
1177
1283
  """
1178
1284
  cache_dir = kwargs.pop("cache_dir", None)
1179
- resume_download = kwargs.pop("resume_download", False)
1180
1285
  force_download = kwargs.pop("force_download", False)
1181
1286
  proxies = kwargs.pop("proxies", None)
1182
1287
  local_files_only = kwargs.pop("local_files_only", None)
@@ -1209,6 +1314,22 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1209
1314
  model_info_call_error = e # save error to reraise it if model is not cached locally
1210
1315
 
1211
1316
  if not local_files_only:
1317
+ filenames = {sibling.rfilename for sibling in info.siblings}
1318
+ if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
1319
+ warn_msg = (
1320
+ f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
1321
+ "Please check your files carefully:\n\n"
1322
+ "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
1323
+ "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
1324
+ "If you find any files in the deprecated format:\n"
1325
+ "1. Remove all existing checkpoint files for this variant.\n"
1326
+ "2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
1327
+ "This will ensure you're using the most up-to-date and compatible checkpoint format."
1328
+ )
1329
+ logger.warning(warn_msg)
1330
+
1331
+ model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
1332
+
1212
1333
  config_file = hf_hub_download(
1213
1334
  pretrained_model_name,
1214
1335
  cls.config_name,
@@ -1216,59 +1337,24 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1216
1337
  revision=revision,
1217
1338
  proxies=proxies,
1218
1339
  force_download=force_download,
1219
- resume_download=resume_download,
1220
1340
  token=token,
1221
1341
  )
1222
1342
 
1223
1343
  config_dict = cls._dict_from_json_file(config_file)
1224
1344
  ignore_filenames = config_dict.pop("_ignore_files", [])
1225
1345
 
1226
- # retrieve all folder_names that contain relevant files
1227
- folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]
1228
-
1229
- filenames = {sibling.rfilename for sibling in info.siblings}
1230
- model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
1231
-
1232
- diffusers_module = importlib.import_module(__name__.split(".")[0])
1233
- pipelines = getattr(diffusers_module, "pipelines")
1234
-
1235
- # optionally create a custom component <> custom file mapping
1236
- custom_components = {}
1237
- for component in folder_names:
1238
- module_candidate = config_dict[component][0]
1239
-
1240
- if module_candidate is None or not isinstance(module_candidate, str):
1241
- continue
1242
-
1243
- # We compute candidate file path on the Hub. Do not use `os.path.join`.
1244
- candidate_file = f"{component}/{module_candidate}.py"
1245
-
1246
- if candidate_file in filenames:
1247
- custom_components[component] = module_candidate
1248
- elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
1249
- raise ValueError(
1250
- f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
1251
- )
1252
-
1253
- if len(variant_filenames) == 0 and variant is not None:
1254
- deprecation_message = (
1255
- f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
1256
- f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`"
1257
- "if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant"
1258
- "modeling files is deprecated."
1259
- )
1260
- deprecate("no variant default", "0.24.0", deprecation_message, standard_warn=False)
1261
-
1262
1346
  # remove ignored filenames
1263
1347
  model_filenames = set(model_filenames) - set(ignore_filenames)
1264
1348
  variant_filenames = set(variant_filenames) - set(ignore_filenames)
1265
1349
 
1266
- # if the whole pipeline is cached we don't have to ping the Hub
1267
1350
  if revision in DEPRECATED_REVISION_ARGS and version.parse(
1268
1351
  version.parse(__version__).base_version
1269
1352
  ) >= version.parse("0.22.0"):
1270
1353
  warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)
1271
1354
 
1355
+ custom_components, folder_names = _get_custom_components_and_folders(
1356
+ pretrained_model_name, config_dict, filenames, variant_filenames, variant
1357
+ )
1272
1358
  model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
1273
1359
 
1274
1360
  custom_class_name = None
@@ -1328,49 +1414,19 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1328
1414
  expected_components, _ = cls._get_signature_keys(pipeline_class)
1329
1415
  passed_components = [k for k in expected_components if k in kwargs]
1330
1416
 
1331
- if (
1332
- use_safetensors
1333
- and not allow_pickle
1334
- and not is_safetensors_compatible(
1335
- model_filenames, variant=variant, passed_components=passed_components
1336
- )
1337
- ):
1338
- raise EnvironmentError(
1339
- f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
1340
- )
1341
- if from_flax:
1342
- ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
1343
- elif use_safetensors and is_safetensors_compatible(
1344
- model_filenames, variant=variant, passed_components=passed_components
1345
- ):
1346
- ignore_patterns = ["*.bin", "*.msgpack"]
1347
-
1348
- use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
1349
- if not use_onnx:
1350
- ignore_patterns += ["*.onnx", "*.pb"]
1351
-
1352
- safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
1353
- safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
1354
- if (
1355
- len(safetensors_variant_filenames) > 0
1356
- and safetensors_model_filenames != safetensors_variant_filenames
1357
- ):
1358
- logger.warning(
1359
- f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
1360
- )
1361
- else:
1362
- ignore_patterns = ["*.safetensors", "*.msgpack"]
1363
-
1364
- use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
1365
- if not use_onnx:
1366
- ignore_patterns += ["*.onnx", "*.pb"]
1367
-
1368
- bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
1369
- bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
1370
- if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
1371
- logger.warning(
1372
- f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
1373
- )
1417
+ # retrieve all patterns that should not be downloaded and error out when needed
1418
+ ignore_patterns = _get_ignore_patterns(
1419
+ passed_components,
1420
+ model_folder_names,
1421
+ model_filenames,
1422
+ variant_filenames,
1423
+ use_safetensors,
1424
+ from_flax,
1425
+ allow_pickle,
1426
+ use_onnx,
1427
+ pipeline_class._is_onnx,
1428
+ variant,
1429
+ )
1374
1430
 
1375
1431
  # Don't download any objects that are passed
1376
1432
  allow_patterns = [
@@ -1382,7 +1438,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1382
1438
 
1383
1439
  # Don't download index files of forbidden patterns either
1384
1440
  ignore_patterns = ignore_patterns + [f"{i}.index.*json" for i in ignore_patterns]
1385
-
1386
1441
  re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns]
1387
1442
  re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns]
1388
1443
 
@@ -1406,7 +1461,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1406
1461
  cached_folder = snapshot_download(
1407
1462
  pretrained_model_name,
1408
1463
  cache_dir=cache_dir,
1409
- resume_download=resume_download,
1410
1464
  proxies=proxies,
1411
1465
  local_files_only=local_files_only,
1412
1466
  token=token,
@@ -1429,7 +1483,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1429
1483
  for connected_pipe_repo_id in connected_pipes:
1430
1484
  download_kwargs = {
1431
1485
  "cache_dir": cache_dir,
1432
- "resume_download": resume_download,
1433
1486
  "force_download": force_download,
1434
1487
  "proxies": proxies,
1435
1488
  "local_files_only": local_files_only,
@@ -1472,6 +1525,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1472
1525
 
1473
1526
  return expected_modules, optional_parameters
1474
1527
 
1528
+ @classmethod
1529
+ def _get_signature_types(cls):
1530
+ signature_types = {}
1531
+ for k, v in inspect.signature(cls.__init__).parameters.items():
1532
+ if inspect.isclass(v.annotation):
1533
+ signature_types[k] = (v.annotation,)
1534
+ elif get_origin(v.annotation) == Union:
1535
+ signature_types[k] = get_args(v.annotation)
1536
+ else:
1537
+ logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.")
1538
+ return signature_types
1539
+
1475
1540
  @property
1476
1541
  def components(self) -> Dict[str, Any]:
1477
1542
  r"""
@@ -1515,6 +1580,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1515
1580
  """
1516
1581
  return numpy_to_pil(images)
1517
1582
 
1583
+ @torch.compiler.disable
1518
1584
  def progress_bar(self, iterable=None, total=None):
1519
1585
  if not hasattr(self, "_progress_bar_config"):
1520
1586
  self._progress_bar_config = {}
@@ -1650,6 +1716,129 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1650
1716
  for module in modules:
1651
1717
  module.set_attention_slice(slice_size)
1652
1718
 
1719
+ @classmethod
1720
+ def from_pipe(cls, pipeline, **kwargs):
1721
+ r"""
1722
+ Create a new pipeline from a given pipeline. This method is useful to create a new pipeline from the existing
1723
+ pipeline components without reallocating additional memory.
1724
+
1725
+ Arguments:
1726
+ pipeline (`DiffusionPipeline`):
1727
+ The pipeline from which to create a new pipeline.
1728
+
1729
+ Returns:
1730
+ `DiffusionPipeline`:
1731
+ A new pipeline with the same weights and configurations as `pipeline`.
1732
+
1733
+ Examples:
1734
+
1735
+ ```py
1736
+ >>> from diffusers import StableDiffusionPipeline, StableDiffusionSAGPipeline
1737
+
1738
+ >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
1739
+ >>> new_pipe = StableDiffusionSAGPipeline.from_pipe(pipe)
1740
+ ```
1741
+ """
1742
+
1743
+ original_config = dict(pipeline.config)
1744
+ torch_dtype = kwargs.pop("torch_dtype", None)
1745
+
1746
+ # derive the pipeline class to instantiate
1747
+ custom_pipeline = kwargs.pop("custom_pipeline", None)
1748
+ custom_revision = kwargs.pop("custom_revision", None)
1749
+
1750
+ if custom_pipeline is not None:
1751
+ pipeline_class = _get_custom_pipeline_class(custom_pipeline, revision=custom_revision)
1752
+ else:
1753
+ pipeline_class = cls
1754
+
1755
+ expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
1756
+ # true_optional_modules are optional components with default value in signature so it is ok not to pass them to `__init__`
1757
+ # e.g. `image_encoder` for StableDiffusionPipeline
1758
+ parameters = inspect.signature(cls.__init__).parameters
1759
+ true_optional_modules = set(
1760
+ {k for k, v in parameters.items() if v.default != inspect._empty and k in expected_modules}
1761
+ )
1762
+
1763
+ # get the class of each component based on its type hint
1764
+ # e.g. {"unet": UNet2DConditionModel, "text_encoder": CLIPTextModel}
1765
+ component_types = pipeline_class._get_signature_types()
1766
+
1767
+ pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
1768
+ # allow users to pass modules in `kwargs` to override the original pipeline's components
1769
+ passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
1770
+
1771
+ original_class_obj = {}
1772
+ for name, component in pipeline.components.items():
1773
+ if name in expected_modules and name not in passed_class_obj:
1774
+ # for model components, we will not switch over if the class does not match the type hint in the new pipeline's signature
1775
+ if (
1776
+ not isinstance(component, ModelMixin)
1777
+ or type(component) in component_types[name]
1778
+ or (component is None and name in cls._optional_components)
1779
+ ):
1780
+ original_class_obj[name] = component
1781
+ else:
1782
+ logger.warning(
1783
+ f"component {name} is not switched over to new pipeline because type does not match the expected."
1784
+ f" {name} is {type(component)} while the new pipeline expect {component_types[name]}."
1785
+ f" please pass the component of the correct type to the new pipeline. `from_pipe(..., {name}={name})`"
1786
+ )
1787
+
1788
+ # allow users to pass optional kwargs to override the original pipeline's config attributes
1789
+ passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
1790
+ original_pipe_kwargs = {
1791
+ k: original_config[k]
1792
+ for k in original_config.keys()
1793
+ if k in optional_kwargs and k not in passed_pipe_kwargs
1794
+ }
1795
+
1796
+ # config attributes that were not expected by the pipeline are stored as its private attributes
1797
+ # (i.e. when the original pipeline was also instantiated with `from_pipe` from another pipeline that has this config)
1798
+ # in this case, we will pass them as optional arguments if they can be accepted by the new pipeline
1799
+ additional_pipe_kwargs = [
1800
+ k[1:]
1801
+ for k in original_config.keys()
1802
+ if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs
1803
+ ]
1804
+ for k in additional_pipe_kwargs:
1805
+ original_pipe_kwargs[k] = original_config.pop(f"_{k}")
1806
+
1807
+ pipeline_kwargs = {
1808
+ **passed_class_obj,
1809
+ **original_class_obj,
1810
+ **passed_pipe_kwargs,
1811
+ **original_pipe_kwargs,
1812
+ **kwargs,
1813
+ }
1814
+
1815
+ # store unused config as private attribute in the new pipeline
1816
+ unused_original_config = {
1817
+ f"{'' if k.startswith('_') else '_'}{k}": v for k, v in original_config.items() if k not in pipeline_kwargs
1818
+ }
1819
+
1820
+ missing_modules = (
1821
+ set(expected_modules)
1822
+ - set(pipeline._optional_components)
1823
+ - set(pipeline_kwargs.keys())
1824
+ - set(true_optional_modules)
1825
+ )
1826
+
1827
+ if len(missing_modules) > 0:
1828
+ raise ValueError(
1829
+ f"Pipeline {pipeline_class} expected {expected_modules}, but only {set(list(passed_class_obj.keys()) + list(original_class_obj.keys()))} were passed"
1830
+ )
1831
+
1832
+ new_pipeline = pipeline_class(**pipeline_kwargs)
1833
+ if pretrained_model_name_or_path is not None:
1834
+ new_pipeline.register_to_config(_name_or_path=pretrained_model_name_or_path)
1835
+ new_pipeline.register_to_config(**unused_original_config)
1836
+
1837
+ if torch_dtype is not None:
1838
+ new_pipeline.to(dtype=torch_dtype)
1839
+
1840
+ return new_pipeline
1841
+
1653
1842
 
1654
1843
  class StableDiffusionMixin:
1655
1844
  r"""
@@ -1713,8 +1902,8 @@ class StableDiffusionMixin:
1713
1902
 
1714
1903
  def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
1715
1904
  """
1716
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
1717
- key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
1905
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
1906
+ are fused. For cross-attention modules, key and value projection matrices are fused.
1718
1907
 
1719
1908
  <Tip warning={true}>
1720
1909