diffusers 0.33.0__py3-none-any.whl → 0.34.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (478)
  1. diffusers/__init__.py +48 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/dependency_versions_check.py +1 -1
  7. diffusers/dependency_versions_table.py +1 -1
  8. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  9. diffusers/hooks/faster_cache.py +2 -2
  10. diffusers/hooks/group_offloading.py +128 -29
  11. diffusers/hooks/hooks.py +2 -2
  12. diffusers/hooks/layerwise_casting.py +3 -3
  13. diffusers/hooks/pyramid_attention_broadcast.py +1 -1
  14. diffusers/image_processor.py +7 -2
  15. diffusers/loaders/__init__.py +4 -0
  16. diffusers/loaders/ip_adapter.py +5 -14
  17. diffusers/loaders/lora_base.py +212 -111
  18. diffusers/loaders/lora_conversion_utils.py +275 -34
  19. diffusers/loaders/lora_pipeline.py +1554 -819
  20. diffusers/loaders/peft.py +52 -109
  21. diffusers/loaders/single_file.py +2 -2
  22. diffusers/loaders/single_file_model.py +20 -4
  23. diffusers/loaders/single_file_utils.py +225 -5
  24. diffusers/loaders/textual_inversion.py +3 -2
  25. diffusers/loaders/transformer_flux.py +1 -1
  26. diffusers/loaders/transformer_sd3.py +2 -2
  27. diffusers/loaders/unet.py +2 -16
  28. diffusers/loaders/unet_loader_utils.py +1 -1
  29. diffusers/loaders/utils.py +1 -1
  30. diffusers/models/__init__.py +15 -1
  31. diffusers/models/activations.py +5 -5
  32. diffusers/models/adapter.py +2 -3
  33. diffusers/models/attention.py +4 -4
  34. diffusers/models/attention_flax.py +10 -10
  35. diffusers/models/attention_processor.py +14 -10
  36. diffusers/models/auto_model.py +47 -10
  37. diffusers/models/autoencoders/__init__.py +1 -0
  38. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  39. diffusers/models/autoencoders/autoencoder_dc.py +3 -3
  40. diffusers/models/autoencoders/autoencoder_kl.py +4 -4
  41. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  42. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  43. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1108 -0
  44. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  45. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  46. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  47. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  48. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  49. diffusers/models/autoencoders/autoencoder_kl_wan.py +256 -22
  50. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  52. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  53. diffusers/models/autoencoders/vae.py +13 -2
  54. diffusers/models/autoencoders/vq_model.py +2 -2
  55. diffusers/models/cache_utils.py +1 -1
  56. diffusers/models/controlnet.py +1 -1
  57. diffusers/models/controlnet_flux.py +1 -1
  58. diffusers/models/controlnet_sd3.py +1 -1
  59. diffusers/models/controlnet_sparsectrl.py +1 -1
  60. diffusers/models/controlnets/__init__.py +1 -0
  61. diffusers/models/controlnets/controlnet.py +3 -3
  62. diffusers/models/controlnets/controlnet_flax.py +1 -1
  63. diffusers/models/controlnets/controlnet_flux.py +16 -15
  64. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  65. diffusers/models/controlnets/controlnet_sana.py +290 -0
  66. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  67. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  68. diffusers/models/controlnets/controlnet_union.py +1 -1
  69. diffusers/models/controlnets/controlnet_xs.py +7 -7
  70. diffusers/models/controlnets/multicontrolnet.py +4 -5
  71. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  72. diffusers/models/downsampling.py +2 -2
  73. diffusers/models/embeddings.py +10 -12
  74. diffusers/models/embeddings_flax.py +2 -2
  75. diffusers/models/lora.py +3 -3
  76. diffusers/models/modeling_utils.py +44 -14
  77. diffusers/models/normalization.py +4 -4
  78. diffusers/models/resnet.py +2 -2
  79. diffusers/models/resnet_flax.py +1 -1
  80. diffusers/models/transformers/__init__.py +5 -0
  81. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  82. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  83. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  84. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  85. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  86. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  87. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  88. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  89. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  90. diffusers/models/transformers/prior_transformer.py +1 -1
  91. diffusers/models/transformers/sana_transformer.py +8 -3
  92. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  93. diffusers/models/transformers/t5_film_transformer.py +3 -3
  94. diffusers/models/transformers/transformer_2d.py +1 -1
  95. diffusers/models/transformers/transformer_allegro.py +1 -1
  96. diffusers/models/transformers/transformer_chroma.py +742 -0
  97. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  98. diffusers/models/transformers/transformer_cogview4.py +317 -25
  99. diffusers/models/transformers/transformer_cosmos.py +579 -0
  100. diffusers/models/transformers/transformer_flux.py +9 -11
  101. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  102. diffusers/models/transformers/transformer_hunyuan_video.py +6 -8
  103. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  104. diffusers/models/transformers/transformer_ltx.py +2 -2
  105. diffusers/models/transformers/transformer_lumina2.py +1 -1
  106. diffusers/models/transformers/transformer_mochi.py +1 -1
  107. diffusers/models/transformers/transformer_omnigen.py +2 -2
  108. diffusers/models/transformers/transformer_sd3.py +7 -7
  109. diffusers/models/transformers/transformer_temporal.py +1 -1
  110. diffusers/models/transformers/transformer_wan.py +24 -8
  111. diffusers/models/transformers/transformer_wan_vace.py +393 -0
  112. diffusers/models/unets/unet_1d.py +1 -1
  113. diffusers/models/unets/unet_1d_blocks.py +1 -1
  114. diffusers/models/unets/unet_2d.py +1 -1
  115. diffusers/models/unets/unet_2d_blocks.py +1 -1
  116. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  117. diffusers/models/unets/unet_2d_condition.py +2 -2
  118. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  119. diffusers/models/unets/unet_3d_blocks.py +1 -1
  120. diffusers/models/unets/unet_3d_condition.py +3 -3
  121. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  122. diffusers/models/unets/unet_kandinsky3.py +1 -1
  123. diffusers/models/unets/unet_motion_model.py +2 -2
  124. diffusers/models/unets/unet_stable_cascade.py +1 -1
  125. diffusers/models/upsampling.py +2 -2
  126. diffusers/models/vae_flax.py +2 -2
  127. diffusers/models/vq_model.py +1 -1
  128. diffusers/pipelines/__init__.py +37 -6
  129. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  130. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  131. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  132. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  133. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  134. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  135. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  136. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  137. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  138. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  139. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  140. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  141. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +23 -13
  142. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  143. diffusers/pipelines/auto_pipeline.py +6 -7
  144. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  145. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  146. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  147. diffusers/pipelines/chroma/__init__.py +49 -0
  148. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  149. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  150. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  151. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +8 -8
  152. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +8 -8
  153. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +8 -8
  154. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +8 -8
  155. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  156. diffusers/pipelines/cogview4/pipeline_cogview4.py +7 -7
  157. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  158. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  159. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  160. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  161. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  162. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +8 -8
  163. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  164. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  165. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  166. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  167. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  168. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +14 -14
  169. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +5 -5
  170. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +13 -13
  171. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  172. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  173. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  174. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  175. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  176. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  177. diffusers/pipelines/cosmos/__init__.py +54 -0
  178. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  179. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  180. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  181. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  182. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  183. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  184. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  185. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  186. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  187. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  188. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  189. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  190. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  191. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  192. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  193. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  194. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  195. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  196. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  197. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  198. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  199. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  200. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  201. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  202. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  203. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  204. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +7 -7
  205. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  206. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  207. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  208. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  209. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  210. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  211. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  212. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  213. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  214. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  215. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  216. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  217. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  218. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  219. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  220. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  221. diffusers/pipelines/flux/modeling_flux.py +1 -1
  222. diffusers/pipelines/flux/pipeline_flux.py +10 -17
  223. diffusers/pipelines/flux/pipeline_flux_control.py +6 -6
  224. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -6
  225. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +6 -6
  226. diffusers/pipelines/flux/pipeline_flux_controlnet.py +6 -6
  227. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +30 -22
  228. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +2 -1
  229. diffusers/pipelines/flux/pipeline_flux_fill.py +6 -6
  230. diffusers/pipelines/flux/pipeline_flux_img2img.py +39 -6
  231. diffusers/pipelines/flux/pipeline_flux_inpaint.py +11 -6
  232. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  233. diffusers/pipelines/free_init_utils.py +2 -2
  234. diffusers/pipelines/free_noise_utils.py +3 -3
  235. diffusers/pipelines/hidream_image/__init__.py +47 -0
  236. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  237. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  238. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  239. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  240. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +8 -8
  241. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  242. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  243. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  244. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  245. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  246. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  247. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  248. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  249. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  250. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  251. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  252. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  253. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  254. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  255. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  256. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  257. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  258. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  259. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  260. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  261. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  262. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  263. diffusers/pipelines/kolors/text_encoder.py +3 -3
  264. diffusers/pipelines/kolors/tokenizer.py +1 -1
  265. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  266. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  267. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  268. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  269. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  270. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  271. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  272. diffusers/pipelines/ltx/__init__.py +4 -0
  273. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  274. diffusers/pipelines/ltx/pipeline_ltx.py +51 -6
  275. diffusers/pipelines/ltx/pipeline_ltx_condition.py +107 -29
  276. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +50 -6
  277. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  278. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  279. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  280. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  281. diffusers/pipelines/mochi/pipeline_mochi.py +6 -6
  282. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  283. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  284. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  285. diffusers/pipelines/onnx_utils.py +15 -2
  286. diffusers/pipelines/pag/pag_utils.py +2 -2
  287. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  288. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  289. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  290. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  291. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  292. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  293. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  294. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  295. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  296. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  297. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  298. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  299. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  300. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  301. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  302. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  303. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  304. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  305. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  306. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  307. diffusers/pipelines/pipeline_flax_utils.py +3 -4
  308. diffusers/pipelines/pipeline_loading_utils.py +89 -13
  309. diffusers/pipelines/pipeline_utils.py +105 -33
  310. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +11 -11
  311. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +11 -11
  312. diffusers/pipelines/sana/__init__.py +4 -0
  313. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  314. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  315. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  316. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  317. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  318. diffusers/pipelines/shap_e/camera.py +1 -1
  319. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  320. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  321. diffusers/pipelines/shap_e/renderer.py +3 -3
  322. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  323. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  324. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  325. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  326. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  327. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  328. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  329. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  330. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  331. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  332. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  333. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +10 -10
  334. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  335. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +10 -10
  336. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +9 -9
  337. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +8 -8
  338. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  339. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  340. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  341. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  342. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  343. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  344. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  345. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  346. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  347. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  348. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  349. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  350. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -7
  351. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  352. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  353. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  354. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  355. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  356. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  357. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  358. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  359. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  360. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  361. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  362. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  363. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  364. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  365. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  366. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  367. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  368. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  369. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  370. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  371. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  372. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  373. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  374. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  375. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  376. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  377. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  378. diffusers/pipelines/unclip/text_proj.py +2 -2
  379. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  380. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  381. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  382. diffusers/pipelines/visualcloze/__init__.py +52 -0
  383. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  384. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  385. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  386. diffusers/pipelines/wan/__init__.py +2 -0
  387. diffusers/pipelines/wan/pipeline_wan.py +17 -12
  388. diffusers/pipelines/wan/pipeline_wan_i2v.py +42 -20
  389. diffusers/pipelines/wan/pipeline_wan_vace.py +976 -0
  390. diffusers/pipelines/wan/pipeline_wan_video2video.py +18 -18
  391. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  392. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  393. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  394. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  395. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  396. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  397. diffusers/quantizers/__init__.py +179 -1
  398. diffusers/quantizers/base.py +6 -1
  399. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  400. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  401. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  402. diffusers/quantizers/gguf/utils.py +16 -13
  403. diffusers/quantizers/quantization_config.py +18 -16
  404. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  405. diffusers/quantizers/torchao/torchao_quantizer.py +5 -1
  406. diffusers/schedulers/__init__.py +3 -1
  407. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  408. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  409. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  410. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  411. diffusers/schedulers/scheduling_ddim.py +8 -8
  412. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  413. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  414. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  415. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  416. diffusers/schedulers/scheduling_ddpm.py +9 -9
  417. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  418. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  419. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  420. diffusers/schedulers/scheduling_deis_multistep.py +8 -8
  421. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  422. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -12
  423. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  424. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  425. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  426. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -13
  427. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  428. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  429. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  430. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  431. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  432. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  433. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  434. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  435. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  436. diffusers/schedulers/scheduling_ipndm.py +2 -2
  437. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  438. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  439. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  440. diffusers/schedulers/scheduling_lcm.py +3 -3
  441. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  442. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  443. diffusers/schedulers/scheduling_pndm.py +4 -4
  444. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  445. diffusers/schedulers/scheduling_repaint.py +9 -9
  446. diffusers/schedulers/scheduling_sasolver.py +15 -15
  447. diffusers/schedulers/scheduling_scm.py +1 -1
  448. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  449. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  450. diffusers/schedulers/scheduling_tcd.py +3 -3
  451. diffusers/schedulers/scheduling_unclip.py +5 -5
  452. diffusers/schedulers/scheduling_unipc_multistep.py +11 -11
  453. diffusers/schedulers/scheduling_utils.py +1 -1
  454. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  455. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  456. diffusers/training_utils.py +13 -5
  457. diffusers/utils/__init__.py +5 -0
  458. diffusers/utils/accelerate_utils.py +1 -1
  459. diffusers/utils/doc_utils.py +1 -1
  460. diffusers/utils/dummy_pt_objects.py +120 -0
  461. diffusers/utils/dummy_torch_and_transformers_objects.py +225 -0
  462. diffusers/utils/dynamic_modules_utils.py +21 -3
  463. diffusers/utils/export_utils.py +1 -1
  464. diffusers/utils/import_utils.py +81 -18
  465. diffusers/utils/logging.py +1 -1
  466. diffusers/utils/outputs.py +2 -1
  467. diffusers/utils/peft_utils.py +91 -8
  468. diffusers/utils/state_dict_utils.py +20 -3
  469. diffusers/utils/testing_utils.py +59 -7
  470. diffusers/utils/torch_utils.py +25 -5
  471. diffusers/video_processor.py +2 -2
  472. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/METADATA +3 -3
  473. diffusers-0.34.0.dist-info/RECORD +639 -0
  474. diffusers-0.33.0.dist-info/RECORD +0 -608
  475. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/LICENSE +0 -0
  476. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/WHEEL +0 -0
  477. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/entry_points.txt +0 -0
  478. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py

@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -25,7 +25,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from ...utils import deprecate, is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
 from ..stable_diffusion import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from .image_encoder import PaintByExampleImageEncoder

@@ -155,7 +155,8 @@ def prepare_mask_and_masked_image(image, mask):
     return mask, masked_image


-class PaintByExamplePipeline(DiffusionPipeline, StableDiffusionMixin):
+class PaintByExamplePipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
+    _last_supported_version = "0.33.1"
     r"""
     <Tip warning={true}>

@@ -239,7 +240,7 @@ class PaintByExamplePipeline(DiffusionPipeline, StableDiffusionMixin):
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
-        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
         # and should be between [0, 1]

         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())

@@ -447,8 +448,8 @@ class PaintByExamplePipeline(DiffusionPipeline, StableDiffusionMixin):
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
-               Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
-               to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+               Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+               applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.

@@ -521,7 +522,7 @@ class PaintByExamplePipeline(DiffusionPipeline, StableDiffusionMixin):
         batch_size = image.shape[0]
         device = self._execution_device
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0

diffusers/pipelines/pia/pipeline_pia.py

@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -46,7 +46,7 @@ from ...utils import (
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
 from ..free_init_utils import FreeInitMixin
-from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin


 if is_torch_xla_available():

@@ -132,6 +132,7 @@ class PIAPipelineOutput(BaseOutput):


 class PIAPipeline(
+    DeprecatedPipelineMixin,
     DiffusionPipeline,
     StableDiffusionMixin,
     TextualInversionLoaderMixin,

@@ -140,6 +141,7 @@ class PIAPipeline(
     FromSingleFileMixin,
     FreeInitMixin,
 ):
+    _last_supported_version = "0.33.1"
     r"""
     Pipeline for text-to-video generation.

@@ -432,7 +434,7 @@ class PIAPipeline(
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
-        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
         # and should be between [0, 1]

         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())

@@ -653,7 +655,7 @@ class PIAPipeline(
         return self._clip_skip

     # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
     # corresponds to doing no classifier free guidance.
     @property
     def do_classifier_free_guidance(self):

@@ -723,8 +725,8 @@ class PIAPipeline(
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            eta (`float`, *optional*, defaults to 0.0):
-               Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
-               to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+               Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+               applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
diffusers/pipelines/pipeline_flax_utils.py

@@ -248,9 +248,8 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
                pretrained pipeline hosted on the Hub.
                - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                  using [`~FlaxDiffusionPipeline.save_pretrained`].
-           dtype (`str` or `jnp.dtype`, *optional*):
-               Override the default `jnp.dtype` and load the model under this dtype. If `"auto"`, the dtype is
-               automatically derived from the model's weights.
+           dtype (`jnp.dtype`, *optional*):
+               Override the default `jnp.dtype` and load the model under this dtype.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.

@@ -469,7 +468,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
                class_obj = import_flax_or_no_model(pipeline_module, class_name)

                importable_classes = ALL_IMPORTABLE_CLASSES
-               class_candidates = {c: class_obj for c in importable_classes.keys()}
+               class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
            else:
                # else we just import it from the library.
                library = importlib.import_module(library_name)
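The `dict.fromkeys` rewrite above (and the matching ones in `pipeline_loading_utils.py` further down) is behavior-preserving. A minimal sketch of the equivalence:

```python
# Minimal sketch: dict.fromkeys(keys, value) builds the same mapping as the old
# comprehension; every key ends up pointing at the same shared value object.
importable_classes = {"ModelMixin": None, "SchedulerMixin": None}  # stand-in keys
class_obj = object()

old_style = {c: class_obj for c in importable_classes.keys()}
new_style = dict.fromkeys(importable_classes.keys(), class_obj)

assert old_style == new_style
assert all(v is class_obj for v in new_style.values())
```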
diffusers/pipelines/pipeline_loading_utils.py

@@ -92,7 +92,7 @@ for library in LOADABLE_CLASSES:
     ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])


-def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool:
+def is_safetensors_compatible(filenames, passed_components=None, folder_names=None, variant=None) -> bool:
     """
     Checking for safetensors compatibility:
     - The model is safetensors compatible only if there is a safetensors file for each model component present in

@@ -103,6 +103,31 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool:
     - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
       extension is replaced with ".safetensors"
     """
+    weight_names = [
+        WEIGHTS_NAME,
+        SAFETENSORS_WEIGHTS_NAME,
+        FLAX_WEIGHTS_NAME,
+        ONNX_WEIGHTS_NAME,
+        ONNX_EXTERNAL_WEIGHTS_NAME,
+    ]
+
+    if is_transformers_available():
+        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
+
+    # model_pytorch, diffusion_model_pytorch, ...
+    weight_prefixes = [w.split(".")[0] for w in weight_names]
+    # .bin, .safetensors, ...
+    weight_suffixs = [w.split(".")[-1] for w in weight_names]
+    # -00001-of-00002
+    transformers_index_format = r"\d{5}-of-\d{5}"
+    # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
+    variant_file_re = re.compile(
+        rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
+    )
+    non_variant_file_re = re.compile(
+        rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
+    )
+
     passed_components = passed_components or []
     if folder_names:
         filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}

@@ -121,15 +146,29 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool:
             components[component].append(component_filename)

     # If there are no component folders check the main directory for safetensors files
+    filtered_filenames = set()
     if not components:
-        return any(".safetensors" in filename for filename in filenames)
+        if variant is not None:
+            filtered_filenames = filter_with_regex(filenames, variant_file_re)
+
+        # If no variant filenames exist check if non-variant files are available
+        if not filtered_filenames:
+            filtered_filenames = filter_with_regex(filenames, non_variant_file_re)
+        return any(".safetensors" in filename for filename in filtered_filenames)

     # iterate over all files of a component
     # check if safetensor files exist for that component
-    # if variant is provided check if the variant of the safetensors exists
     for component, component_filenames in components.items():
         matches = []
-        for component_filename in component_filenames:
+        filtered_component_filenames = set()
+        # if variant is provided check if the variant of the safetensors exists
+        if variant is not None:
+            filtered_component_filenames = filter_with_regex(component_filenames, variant_file_re)
+
+        # if variant safetensor files do not exist check for non-variants
+        if not filtered_component_filenames:
+            filtered_component_filenames = filter_with_regex(component_filenames, non_variant_file_re)
+        for component_filename in filtered_component_filenames:
             filename, extension = os.path.splitext(component_filename)

             match_exists = extension == ".safetensors"
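To illustrate the two patterns above, here is a hedged, standalone sketch (not library code) specialized to a couple of representative weight names and `variant="fp16"`:

```python
import re

# Representative prefixes/suffixes; the real lists are derived from the weight
# name constants shown in the hunk above.
weight_prefixes = ["diffusion_pytorch_model", "model"]
weight_suffixes = ["bin", "safetensors"]
transformers_index_format = r"\d{5}-of-\d{5}"
variant = "fp16"

variant_file_re = re.compile(
    rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixes)})$"
)
non_variant_file_re = re.compile(
    rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixes)})$"
)

print(bool(variant_file_re.match("diffusion_pytorch_model.fp16.safetensors")))  # True
print(bool(variant_file_re.match("model.fp16-00001-of-00002.safetensors")))     # True (sharded variant)
print(bool(variant_file_re.match("diffusion_pytorch_model.safetensors")))       # False (no variant tag)
print(bool(non_variant_file_re.match("diffusion_pytorch_model.safetensors")))   # True
print(bool(non_variant_file_re.match("model-00001-of-00002.safetensors")))      # True (sharded, no variant)
```

When a `variant` is passed, only variant-matching files are considered first; if none exist, the check falls back to the non-variant pattern, mirroring the control flow added above.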
@@ -159,6 +198,10 @@ def filter_model_files(filenames):
     return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)]


+def filter_with_regex(filenames, pattern_re):
+    return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
+
+
 def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
     weight_names = [
         WEIGHTS_NAME,

@@ -207,9 +250,6 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
         # interested in the extension name
         return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)}

-    def filter_with_regex(filenames, pattern_re):
-        return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
-
     # Group files by component
     components = {}
     for filename in filenames:

@@ -335,19 +375,19 @@ def get_class_obj_and_candidates(
     library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
 ):
     """Simple helper method to retrieve class object of module as well as potential parent class objects"""
-    component_folder = os.path.join(cache_dir, component_name)
+    component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None

     if is_pipeline_module:
         pipeline_module = getattr(pipelines, library_name)

         class_obj = getattr(pipeline_module, class_name)
-        class_candidates = {c: class_obj for c in importable_classes.keys()}
-    elif os.path.isfile(os.path.join(component_folder, library_name + ".py")):
+        class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
+    elif component_folder and os.path.isfile(os.path.join(component_folder, library_name + ".py")):
         # load custom component
         class_obj = get_class_from_dynamic_module(
             component_folder, module_file=library_name + ".py", class_name=class_name
         )
-        class_candidates = {c: class_obj for c in importable_classes.keys()}
+        class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
     else:
         # else we just import it from the library.
         library = importlib.import_module(library_name)

@@ -675,8 +715,10 @@ def load_sub_model(
     use_safetensors: bool,
     dduf_entries: Optional[Dict[str, DDUFEntry]],
     provider_options: Any,
+    quantization_config: Optional[Any] = None,
 ):
     """Helper method to load the module `name` from `library_name` and `class_name`"""
+    from ..quantizers import PipelineQuantizationConfig

     # retrieve class candidates


@@ -769,6 +811,17 @@ def load_sub_model(
     else:
         loading_kwargs["low_cpu_mem_usage"] = False

+    if (
+        quantization_config is not None
+        and isinstance(quantization_config, PipelineQuantizationConfig)
+        and issubclass(class_obj, torch.nn.Module)
+    ):
+        model_quant_config = quantization_config._resolve_quant_config(
+            is_diffusers=is_diffusers_model, module_name=name
+        )
+        if model_quant_config is not None:
+            loading_kwargs["quantization_config"] = model_quant_config
+
     # check if the module is in a subdirectory
     if dduf_entries:
         loading_kwargs["dduf_entries"] = dduf_entries

@@ -984,7 +1037,7 @@ def _get_ignore_patterns(
        use_safetensors
        and not allow_pickle
        and not is_safetensors_compatible(
-           model_filenames, passed_components=passed_components, folder_names=model_folder_names
+           model_filenames, passed_components=passed_components, folder_names=model_folder_names, variant=variant
        )
    ):
        raise EnvironmentError(

@@ -995,7 +1048,7 @@
        ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]

    elif use_safetensors and is_safetensors_compatible(
-       model_filenames, passed_components=passed_components, folder_names=model_folder_names
+       model_filenames, passed_components=passed_components, folder_names=model_folder_names, variant=variant
    ):
        ignore_patterns = ["*.bin", "*.msgpack"]

@@ -1078,3 +1131,26 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict):
            break
    if has_transformers_component and not is_transformers_version(">", "4.47.1"):
        raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")
+
+
+def _maybe_warn_for_wrong_component_in_quant_config(pipe_init_dict, quant_config):
+    if quant_config is None:
+        return
+
+    actual_pipe_components = set(pipe_init_dict.keys())
+    missing = ""
+    quant_components = None
+    if getattr(quant_config, "components_to_quantize", None) is not None:
+        quant_components = set(quant_config.components_to_quantize)
+    elif getattr(quant_config, "quant_mapping", None) is not None and isinstance(quant_config.quant_mapping, dict):
+        quant_components = set(quant_config.quant_mapping.keys())
+
+    if quant_components and not quant_components.issubset(actual_pipe_components):
+        missing = quant_components - actual_pipe_components
+
+    if missing:
+        logger.warning(
+            f"The following components in the quantization config {missing} will be ignored "
+            "as they do not belong to the underlying pipeline. Acceptable values for the pipeline "
+            f"components are: {', '.join(actual_pipe_components)}."
+        )
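A hedged sketch of how the check above behaves, using a plain stand-in object instead of the real `PipelineQuantizationConfig` (the misspelled component name is deliberate, to trigger the warning path):

```python
from types import SimpleNamespace

# Stand-in for a quantization config; only the attributes read above are modeled.
quant_config = SimpleNamespace(components_to_quantize=["transformer", "text_encodr"], quant_mapping=None)
pipe_init_dict = {"transformer": None, "text_encoder": None, "vae": None, "scheduler": None}

quant_components = set(quant_config.components_to_quantize)
missing = quant_components - set(pipe_init_dict.keys())
print(missing)  # {'text_encodr'} -> would be ignored, and a warning listing the valid components is logged
```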
diffusers/pipelines/pipeline_utils.py

@@ -47,6 +47,7 @@ from ..configuration_utils import ConfigMixin
 from ..models import AutoencoderKL
 from ..models.attention_processor import FusedAttnProcessor2_0
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
+from ..quantizers import PipelineQuantizationConfig
 from ..quantizers.bitsandbytes.utils import _check_bnb_status
 from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from ..utils import (

@@ -58,6 +59,7 @@ from ..utils import (
     _is_valid_type,
     is_accelerate_available,
     is_accelerate_version,
+    is_hpu_available,
     is_torch_npu_available,
     is_torch_version,
     is_transformers_version,

@@ -65,7 +67,7 @@ from ..utils import (
     numpy_to_pil,
 )
 from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
-from ..utils.torch_utils import is_compiled_module
+from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module


 if is_torch_npu_available():

@@ -86,6 +88,7 @@ from .pipeline_loading_utils import (
     _identify_model_variants,
     _maybe_raise_error_for_incorrect_transformers,
     _maybe_raise_warning_for_inpainting,
+    _maybe_warn_for_wrong_component_in_quant_config,
     _resolve_custom_pipeline_and_cls,
     _unwrap_model,
     _update_init_kwargs_with_connected_pipeline,

@@ -137,6 +140,43 @@ class AudioPipelineOutput(BaseOutput):
     audios: np.ndarray


+class DeprecatedPipelineMixin:
+    """
+    A mixin that can be used to mark a pipeline as deprecated.
+
+    Pipelines inheriting from this mixin will raise a warning when instantiated, indicating that they are deprecated
+    and won't receive updates past the specified version. Tests will be skipped for pipelines that inherit from this
+    mixin.
+
+    Example usage:
+    ```python
+    class MyDeprecatedPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
+        _last_supported_version = "0.20.0"
+
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+    ```
+    """
+
+    # Override this in the inheriting class to specify the last version that will support this pipeline
+    _last_supported_version = None
+
+    def __init__(self, *args, **kwargs):
+        # Get the class name for the warning message
+        class_name = self.__class__.__name__
+
+        # Get the last supported version or use the current version if not specified
+        version_info = getattr(self.__class__, "_last_supported_version", __version__)
+
+        # Raise a warning that this pipeline is deprecated
+        logger.warning(
+            f"The {class_name} has been deprecated and will not receive bug fixes or feature updates after Diffusers version {version_info}. "
+        )
+
+        # Call the parent class's __init__ method
+        super().__init__(*args, **kwargs)
+
+
 class DiffusionPipeline(ConfigMixin, PushToHubMixin):
     r"""
     Base class for all pipelines.

@@ -404,6 +444,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
             return False

+        _, _, is_loaded_in_8bit_bnb = _check_bnb_status(module)
+
+        if is_loaded_in_8bit_bnb:
+            return False
+
         return hasattr(module, "_hf_hook") and (
             isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
             or hasattr(module._hf_hook, "hooks")

@@ -445,6 +490,20 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
            )

+        # Enable generic support for Intel Gaudi accelerator using GPU/HPU migration
+        if device_type == "hpu" and kwargs.pop("hpu_migration", True) and is_hpu_available():
+            os.environ["PT_HPU_GPU_MIGRATION"] = "1"
+            logger.debug("Environment variable set: PT_HPU_GPU_MIGRATION=1")
+
+            import habana_frameworks.torch  # noqa: F401
+
+            # HPU hardware check
+            if not (hasattr(torch, "hpu") and torch.hpu.is_available()):
+                raise ValueError("You are trying to call `.to('hpu')` but HPU device is unavailable.")
+
+            os.environ["PT_HPU_MAX_COMPOUND_OP_SIZE"] = "1"
+            logger.debug("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1")
+
         module_names, _ = self._get_signature_keys(self)
         modules = [getattr(self, n, None) for n in module_names]
         modules = [m for m in modules if isinstance(m, torch.nn.Module)]
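A hedged usage sketch for the new `.to("hpu")` path (assumes an Intel Gaudi host with `habana_frameworks` installed; the checkpoint name is only an example):

```python
import torch
from diffusers import DiffusionPipeline

# Example checkpoint; any supported pipeline follows the same pattern.
pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.bfloat16
)

# On a Gaudi machine this sets PT_HPU_GPU_MIGRATION=1 and PT_HPU_MAX_COMPOUND_OP_SIZE=1
# before moving the modules; pass hpu_migration=False to skip the migration shim.
pipe.to("hpu")

image = pipe("a photo of an astronaut riding a horse on the moon").images[0]
```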
@@ -552,12 +611,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  saved using
  [`~DiffusionPipeline.save_pretrained`].
  - A path to a *directory* (for example `./my_pipeline_directory/`) containing a dduf file
- torch_dtype (`str` or `torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*):
- Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
- dtype is automatically derived from the model's weights. To load submodels with different dtype pass a
- `dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`). Set the default dtype for
- unspecified components with `default` (for example `{'transformer': torch.bfloat16, 'default':
- torch.float16}`). If a component is not specified and no default is set, `torch.float32` is used.
+ torch_dtype (`torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype. To load submodels with
+ different dtype pass a `dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`).
+ Set the default dtype for unspecified components with `default` (for example `{'transformer':
+ torch.bfloat16, 'default': torch.float16}`). If a component is not specified and no default is set,
+ `torch.float32` is used.
  custom_pipeline (`str`, *optional*):

  <Tip warning={true}>
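The reworded `torch_dtype` docstring above covers per-component dtypes via a `dict`. A small sketch of that usage; the repository id and component names are only an example:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype={"unet": torch.bfloat16, "default": torch.float16},
)

# Components listed in the dict get their own dtype; everything else uses `default`.
print(pipe.unet.dtype)          # torch.bfloat16
print(pipe.text_encoder.dtype)  # torch.float16
```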
@@ -611,14 +670,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
  guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
  information.
- device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
- A map that specifies where each submodule should go. It doesn’t need to be defined for each
- parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
- same device.
-
- Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
- more information about each option see [designing a device
- map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ device_map (`str`, *optional*):
+ Strategy that dictates how the different components of a pipeline should be placed on available
+ devices. Currently, only "balanced" `device_map` is supported. Check out
+ [this](https://huggingface.co/docs/diffusers/main/en/tutorials/inference_with_big_models#device-placement)
+ to know more.
  max_memory (`Dict`, *optional*):
  A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
  each GPU and the available CPU RAM if unset.
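The narrowed `device_map` docstring above only advertises the `"balanced"` strategy for pipelines. A short sketch of what that looks like on a multi-GPU machine; the checkpoint is illustrative:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    device_map="balanced",  # currently the only supported pipeline-level strategy
)

# Inspect where each component ended up, e.g. {'unet': 0, 'text_encoder': 1, ...}.
print(pipe.hf_device_map)
```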
@@ -705,6 +761,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  use_safetensors = kwargs.pop("use_safetensors", None)
  use_onnx = kwargs.pop("use_onnx", None)
  load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
+ quantization_config = kwargs.pop("quantization_config", None)

  if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
  torch_dtype = torch.float32
@@ -721,6 +778,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  " install accelerate\n```\n."
  )

+ if quantization_config is not None and not isinstance(quantization_config, PipelineQuantizationConfig):
+ raise ValueError("`quantization_config` must be an instance of `PipelineQuantizationConfig`.")
+
  if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
  raise NotImplementedError(
  "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
@@ -925,6 +985,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):

  # 7. Load each module in the pipeline
  current_device_map = None
+ _maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
  for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
  # 7.1 device_map shenanigans
  if final_device_map is not None and len(final_device_map) > 0:
@@ -981,6 +1042,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  use_safetensors=use_safetensors,
  dduf_entries=dduf_entries,
  provider_options=provider_options,
+ quantization_config=quantization_config,
  )
  logger.info(
  f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
@@ -1084,19 +1146,20 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  accelerate.hooks.remove_hook_from_module(model, recurse=True)
  self._all_hooks = []

- def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
  r"""
  Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
- to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
- method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
- `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the accelerator when its
+ `forward` method is called, and the model remains in accelerator until the next model runs. Memory savings are
+ lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution
+ of the `unet`.

  Arguments:
  gpu_id (`int`, *optional*):
  The ID of the accelerator that shall be used in inference. If not specified, it will default to 0.
- device (`torch.Device` or `str`, *optional*, defaults to "cuda"):
+ device (`torch.Device` or `str`, *optional*, defaults to None):
  The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
- default to "cuda".
+ automatically detect the available accelerator and use.
  """
  self._maybe_raise_error_if_group_offload_active(raise_error=True)

@@ -1118,6 +1181,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):

  self.remove_all_hooks()

+ if device is None:
+ device = get_device()
+ if device == "cpu":
+ raise RuntimeError("`enable_model_cpu_offload` requires accelerator, but not found")
+
  torch_device = torch.device(device)
  device_index = torch_device.index

@@ -1135,9 +1203,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  self._offload_device = device

  self.to("cpu", silence_dtype_warnings=True)
- device_mod = getattr(torch, device.type, None)
- if hasattr(device_mod, "empty_cache") and device_mod.is_available():
- device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+ empty_device_cache(device.type)

  all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
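With `device=None` as the new default, the offload target is resolved via `get_device()` instead of assuming `"cuda"`, and `empty_device_cache` replaces the manual per-backend cache clearing. A minimal usage sketch (checkpoint name illustrative):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
)

# No device argument: CUDA, XPU, MPS, etc. are detected automatically;
# on a CPU-only machine this now raises instead of silently targeting "cuda".
pipe.enable_model_cpu_offload()

image = pipe("a watercolor painting of a lighthouse").images[0]
```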
 
@@ -1196,20 +1262,20 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  # make sure the model is in the same state as before calling it
  self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))

- def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
  r"""
  Offloads all models to CPU using 🤗 Accelerate, significantly reducing memory usage. When called, the state
  dicts of all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are saved to CPU
- and then moved to `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward`
- method called. Offloading happens on a submodule basis. Memory savings are higher than with
+ and then moved to `torch.device('meta')` and loaded to accelerator only when their specific submodule has its
+ `forward` method called. Offloading happens on a submodule basis. Memory savings are higher than with
  `enable_model_cpu_offload`, but performance is lower.

  Arguments:
  gpu_id (`int`, *optional*):
  The ID of the accelerator that shall be used in inference. If not specified, it will default to 0.
- device (`torch.Device` or `str`, *optional*, defaults to "cuda"):
+ device (`torch.Device` or `str`, *optional*, defaults to None):
  The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
- default to "cuda".
+ automatically detect the available accelerator and use.
  """
  self._maybe_raise_error_if_group_offload_active(raise_error=True)

@@ -1225,6 +1291,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
  )

+ if device is None:
+ device = get_device()
+ if device == "cpu":
+ raise RuntimeError("`enable_sequential_cpu_offload` requires accelerator, but not found")
+
  torch_device = torch.device(device)
  device_index = torch_device.index

@@ -1242,10 +1313,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  self._offload_device = device

  if self.device.type != "cpu":
+ orig_device_type = self.device.type
  self.to("cpu", silence_dtype_warnings=True)
- device_mod = getattr(torch, self.device.type, None)
- if hasattr(device_mod, "empty_cache") and device_mod.is_available():
- device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+ empty_device_cache(orig_device_type)

  for name, model in self.components.items():
  if not isinstance(model, torch.nn.Module):
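`enable_sequential_cpu_offload` gets the same treatment: an auto-detected accelerator, an explicit error on CPU-only hosts, and the shared `empty_device_cache` helper. A usage sketch (illustrative checkpoint):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
)

# Submodule-level offloading: slowest option, but the lowest peak memory.
pipe.enable_sequential_cpu_offload()

image = pipe("a foggy mountain valley at sunrise").images[0]
```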
@@ -1628,6 +1698,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
  signature_types[k] = (v.annotation,)
  elif get_origin(v.annotation) == Union:
  signature_types[k] = get_args(v.annotation)
+ elif get_origin(v.annotation) in [List, Dict, list, dict]:
+ signature_types[k] = (v.annotation,)
  else:
  logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.")
  return signature_types
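The extra `elif` exists because `get_args` on a parameterized container such as `List[...]` or `Dict[...]` returns the element types rather than alternative component types, so the annotation is kept whole instead of being unpacked like a `Union`. A quick illustration:

```python
from typing import Dict, List, Union, get_args, get_origin

print(get_origin(Union[int, str]), get_args(Union[int, str]))  # typing.Union (int, str)
print(get_origin(List[int]), get_args(List[int]))              # list (int,)
print(get_origin(Dict[str, int]), get_args(Dict[str, int]))    # dict (str, int)
```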
@@ -1990,7 +2062,7 @@ class StableDiffusionMixin:
  self.vae.disable_tiling()

  def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
- r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+ r"""Enables the FreeU mechanism as in https://huggingface.co/papers/2309.11497.

  The suffixes after the scaling factors represent the stages where they are being applied.
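For context on the retargeted FreeU reference: the mechanism is switched on per pipeline with four scaling factors. A small usage sketch; the values shown are ones commonly suggested for SD 1.5-style UNets and are an assumption here, not part of this diff:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# s1/s2 scale skip connections, b1/b2 scale backbone features in the upsampling stages.
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6)
image = pipe("a fantasy castle, highly detailed").images[0]

# Call pipe.disable_freeu() to turn the mechanism off again.
```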