diffusers-0.33.0-py3-none-any.whl → diffusers-0.34.0-py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (478)
  1. diffusers/__init__.py +48 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/dependency_versions_check.py +1 -1
  7. diffusers/dependency_versions_table.py +1 -1
  8. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  9. diffusers/hooks/faster_cache.py +2 -2
  10. diffusers/hooks/group_offloading.py +128 -29
  11. diffusers/hooks/hooks.py +2 -2
  12. diffusers/hooks/layerwise_casting.py +3 -3
  13. diffusers/hooks/pyramid_attention_broadcast.py +1 -1
  14. diffusers/image_processor.py +7 -2
  15. diffusers/loaders/__init__.py +4 -0
  16. diffusers/loaders/ip_adapter.py +5 -14
  17. diffusers/loaders/lora_base.py +212 -111
  18. diffusers/loaders/lora_conversion_utils.py +275 -34
  19. diffusers/loaders/lora_pipeline.py +1554 -819
  20. diffusers/loaders/peft.py +52 -109
  21. diffusers/loaders/single_file.py +2 -2
  22. diffusers/loaders/single_file_model.py +20 -4
  23. diffusers/loaders/single_file_utils.py +225 -5
  24. diffusers/loaders/textual_inversion.py +3 -2
  25. diffusers/loaders/transformer_flux.py +1 -1
  26. diffusers/loaders/transformer_sd3.py +2 -2
  27. diffusers/loaders/unet.py +2 -16
  28. diffusers/loaders/unet_loader_utils.py +1 -1
  29. diffusers/loaders/utils.py +1 -1
  30. diffusers/models/__init__.py +15 -1
  31. diffusers/models/activations.py +5 -5
  32. diffusers/models/adapter.py +2 -3
  33. diffusers/models/attention.py +4 -4
  34. diffusers/models/attention_flax.py +10 -10
  35. diffusers/models/attention_processor.py +14 -10
  36. diffusers/models/auto_model.py +47 -10
  37. diffusers/models/autoencoders/__init__.py +1 -0
  38. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  39. diffusers/models/autoencoders/autoencoder_dc.py +3 -3
  40. diffusers/models/autoencoders/autoencoder_kl.py +4 -4
  41. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  42. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  43. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1108 -0
  44. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  45. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  46. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  47. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  48. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  49. diffusers/models/autoencoders/autoencoder_kl_wan.py +256 -22
  50. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  52. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  53. diffusers/models/autoencoders/vae.py +13 -2
  54. diffusers/models/autoencoders/vq_model.py +2 -2
  55. diffusers/models/cache_utils.py +1 -1
  56. diffusers/models/controlnet.py +1 -1
  57. diffusers/models/controlnet_flux.py +1 -1
  58. diffusers/models/controlnet_sd3.py +1 -1
  59. diffusers/models/controlnet_sparsectrl.py +1 -1
  60. diffusers/models/controlnets/__init__.py +1 -0
  61. diffusers/models/controlnets/controlnet.py +3 -3
  62. diffusers/models/controlnets/controlnet_flax.py +1 -1
  63. diffusers/models/controlnets/controlnet_flux.py +16 -15
  64. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  65. diffusers/models/controlnets/controlnet_sana.py +290 -0
  66. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  67. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  68. diffusers/models/controlnets/controlnet_union.py +1 -1
  69. diffusers/models/controlnets/controlnet_xs.py +7 -7
  70. diffusers/models/controlnets/multicontrolnet.py +4 -5
  71. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  72. diffusers/models/downsampling.py +2 -2
  73. diffusers/models/embeddings.py +10 -12
  74. diffusers/models/embeddings_flax.py +2 -2
  75. diffusers/models/lora.py +3 -3
  76. diffusers/models/modeling_utils.py +44 -14
  77. diffusers/models/normalization.py +4 -4
  78. diffusers/models/resnet.py +2 -2
  79. diffusers/models/resnet_flax.py +1 -1
  80. diffusers/models/transformers/__init__.py +5 -0
  81. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  82. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  83. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  84. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  85. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  86. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  87. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  88. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  89. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  90. diffusers/models/transformers/prior_transformer.py +1 -1
  91. diffusers/models/transformers/sana_transformer.py +8 -3
  92. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  93. diffusers/models/transformers/t5_film_transformer.py +3 -3
  94. diffusers/models/transformers/transformer_2d.py +1 -1
  95. diffusers/models/transformers/transformer_allegro.py +1 -1
  96. diffusers/models/transformers/transformer_chroma.py +742 -0
  97. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  98. diffusers/models/transformers/transformer_cogview4.py +317 -25
  99. diffusers/models/transformers/transformer_cosmos.py +579 -0
  100. diffusers/models/transformers/transformer_flux.py +9 -11
  101. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  102. diffusers/models/transformers/transformer_hunyuan_video.py +6 -8
  103. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  104. diffusers/models/transformers/transformer_ltx.py +2 -2
  105. diffusers/models/transformers/transformer_lumina2.py +1 -1
  106. diffusers/models/transformers/transformer_mochi.py +1 -1
  107. diffusers/models/transformers/transformer_omnigen.py +2 -2
  108. diffusers/models/transformers/transformer_sd3.py +7 -7
  109. diffusers/models/transformers/transformer_temporal.py +1 -1
  110. diffusers/models/transformers/transformer_wan.py +24 -8
  111. diffusers/models/transformers/transformer_wan_vace.py +393 -0
  112. diffusers/models/unets/unet_1d.py +1 -1
  113. diffusers/models/unets/unet_1d_blocks.py +1 -1
  114. diffusers/models/unets/unet_2d.py +1 -1
  115. diffusers/models/unets/unet_2d_blocks.py +1 -1
  116. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  117. diffusers/models/unets/unet_2d_condition.py +2 -2
  118. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  119. diffusers/models/unets/unet_3d_blocks.py +1 -1
  120. diffusers/models/unets/unet_3d_condition.py +3 -3
  121. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  122. diffusers/models/unets/unet_kandinsky3.py +1 -1
  123. diffusers/models/unets/unet_motion_model.py +2 -2
  124. diffusers/models/unets/unet_stable_cascade.py +1 -1
  125. diffusers/models/upsampling.py +2 -2
  126. diffusers/models/vae_flax.py +2 -2
  127. diffusers/models/vq_model.py +1 -1
  128. diffusers/pipelines/__init__.py +37 -6
  129. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  130. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  131. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  132. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  133. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  134. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  135. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  136. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  137. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  138. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  139. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  140. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  141. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +23 -13
  142. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  143. diffusers/pipelines/auto_pipeline.py +6 -7
  144. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  145. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  146. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  147. diffusers/pipelines/chroma/__init__.py +49 -0
  148. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  149. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  150. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  151. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +8 -8
  152. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +8 -8
  153. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +8 -8
  154. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +8 -8
  155. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  156. diffusers/pipelines/cogview4/pipeline_cogview4.py +7 -7
  157. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  158. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  159. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  160. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  161. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  162. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +8 -8
  163. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  164. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  165. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  166. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  167. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  168. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +14 -14
  169. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +5 -5
  170. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +13 -13
  171. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  172. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  173. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  174. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  175. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  176. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  177. diffusers/pipelines/cosmos/__init__.py +54 -0
  178. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  179. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  180. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  181. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  182. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  183. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  184. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  185. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  186. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  187. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  188. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  189. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  190. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  191. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  192. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  193. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  194. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  195. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  196. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  197. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  198. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  199. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  200. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  201. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  202. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  203. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  204. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +7 -7
  205. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  206. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  207. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  208. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  209. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  210. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  211. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  212. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  213. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  214. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  215. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  216. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  217. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  218. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  219. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  220. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  221. diffusers/pipelines/flux/modeling_flux.py +1 -1
  222. diffusers/pipelines/flux/pipeline_flux.py +10 -17
  223. diffusers/pipelines/flux/pipeline_flux_control.py +6 -6
  224. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -6
  225. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +6 -6
  226. diffusers/pipelines/flux/pipeline_flux_controlnet.py +6 -6
  227. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +30 -22
  228. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +2 -1
  229. diffusers/pipelines/flux/pipeline_flux_fill.py +6 -6
  230. diffusers/pipelines/flux/pipeline_flux_img2img.py +39 -6
  231. diffusers/pipelines/flux/pipeline_flux_inpaint.py +11 -6
  232. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  233. diffusers/pipelines/free_init_utils.py +2 -2
  234. diffusers/pipelines/free_noise_utils.py +3 -3
  235. diffusers/pipelines/hidream_image/__init__.py +47 -0
  236. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  237. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  238. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  239. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  240. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +8 -8
  241. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  242. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  243. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  244. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  245. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  246. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  247. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  248. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  249. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  250. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  251. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  252. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  253. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  254. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  255. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  256. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  257. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  258. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  259. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  260. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  261. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  262. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  263. diffusers/pipelines/kolors/text_encoder.py +3 -3
  264. diffusers/pipelines/kolors/tokenizer.py +1 -1
  265. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  266. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  267. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  268. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  269. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  270. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  271. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  272. diffusers/pipelines/ltx/__init__.py +4 -0
  273. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  274. diffusers/pipelines/ltx/pipeline_ltx.py +51 -6
  275. diffusers/pipelines/ltx/pipeline_ltx_condition.py +107 -29
  276. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +50 -6
  277. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  278. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  279. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  280. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  281. diffusers/pipelines/mochi/pipeline_mochi.py +6 -6
  282. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  283. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  284. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  285. diffusers/pipelines/onnx_utils.py +15 -2
  286. diffusers/pipelines/pag/pag_utils.py +2 -2
  287. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  288. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  289. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  290. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  291. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  292. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  293. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  294. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  295. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  296. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  297. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  298. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  299. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  300. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  301. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  302. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  303. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  304. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  305. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  306. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  307. diffusers/pipelines/pipeline_flax_utils.py +3 -4
  308. diffusers/pipelines/pipeline_loading_utils.py +89 -13
  309. diffusers/pipelines/pipeline_utils.py +105 -33
  310. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +11 -11
  311. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +11 -11
  312. diffusers/pipelines/sana/__init__.py +4 -0
  313. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  314. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  315. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  316. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  317. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  318. diffusers/pipelines/shap_e/camera.py +1 -1
  319. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  320. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  321. diffusers/pipelines/shap_e/renderer.py +3 -3
  322. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  323. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  324. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  325. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  326. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  327. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  328. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  329. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  330. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  331. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  332. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  333. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +10 -10
  334. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  335. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +10 -10
  336. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +9 -9
  337. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +8 -8
  338. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  339. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  340. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  341. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  342. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  343. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  344. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  345. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  346. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  347. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  348. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  349. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  350. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -7
  351. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  352. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  353. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  354. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  355. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  356. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  357. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  358. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  359. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  360. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  361. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  362. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  363. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  364. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  365. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  366. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  367. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  368. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  369. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  370. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  371. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  372. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  373. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  374. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  375. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  376. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  377. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  378. diffusers/pipelines/unclip/text_proj.py +2 -2
  379. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  380. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  381. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  382. diffusers/pipelines/visualcloze/__init__.py +52 -0
  383. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  384. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  385. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  386. diffusers/pipelines/wan/__init__.py +2 -0
  387. diffusers/pipelines/wan/pipeline_wan.py +17 -12
  388. diffusers/pipelines/wan/pipeline_wan_i2v.py +42 -20
  389. diffusers/pipelines/wan/pipeline_wan_vace.py +976 -0
  390. diffusers/pipelines/wan/pipeline_wan_video2video.py +18 -18
  391. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  392. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  393. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  394. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  395. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  396. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  397. diffusers/quantizers/__init__.py +179 -1
  398. diffusers/quantizers/base.py +6 -1
  399. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  400. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  401. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  402. diffusers/quantizers/gguf/utils.py +16 -13
  403. diffusers/quantizers/quantization_config.py +18 -16
  404. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  405. diffusers/quantizers/torchao/torchao_quantizer.py +5 -1
  406. diffusers/schedulers/__init__.py +3 -1
  407. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  408. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  409. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  410. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  411. diffusers/schedulers/scheduling_ddim.py +8 -8
  412. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  413. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  414. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  415. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  416. diffusers/schedulers/scheduling_ddpm.py +9 -9
  417. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  418. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  419. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  420. diffusers/schedulers/scheduling_deis_multistep.py +8 -8
  421. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  422. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -12
  423. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  424. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  425. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  426. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -13
  427. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  428. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  429. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  430. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  431. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  432. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  433. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  434. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  435. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  436. diffusers/schedulers/scheduling_ipndm.py +2 -2
  437. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  438. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  439. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  440. diffusers/schedulers/scheduling_lcm.py +3 -3
  441. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  442. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  443. diffusers/schedulers/scheduling_pndm.py +4 -4
  444. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  445. diffusers/schedulers/scheduling_repaint.py +9 -9
  446. diffusers/schedulers/scheduling_sasolver.py +15 -15
  447. diffusers/schedulers/scheduling_scm.py +1 -1
  448. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  449. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  450. diffusers/schedulers/scheduling_tcd.py +3 -3
  451. diffusers/schedulers/scheduling_unclip.py +5 -5
  452. diffusers/schedulers/scheduling_unipc_multistep.py +11 -11
  453. diffusers/schedulers/scheduling_utils.py +1 -1
  454. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  455. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  456. diffusers/training_utils.py +13 -5
  457. diffusers/utils/__init__.py +5 -0
  458. diffusers/utils/accelerate_utils.py +1 -1
  459. diffusers/utils/doc_utils.py +1 -1
  460. diffusers/utils/dummy_pt_objects.py +120 -0
  461. diffusers/utils/dummy_torch_and_transformers_objects.py +225 -0
  462. diffusers/utils/dynamic_modules_utils.py +21 -3
  463. diffusers/utils/export_utils.py +1 -1
  464. diffusers/utils/import_utils.py +81 -18
  465. diffusers/utils/logging.py +1 -1
  466. diffusers/utils/outputs.py +2 -1
  467. diffusers/utils/peft_utils.py +91 -8
  468. diffusers/utils/state_dict_utils.py +20 -3
  469. diffusers/utils/testing_utils.py +59 -7
  470. diffusers/utils/torch_utils.py +25 -5
  471. diffusers/video_processor.py +2 -2
  472. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/METADATA +3 -3
  473. diffusers-0.34.0.dist-info/RECORD +639 -0
  474. diffusers-0.33.0.dist-info/RECORD +0 -608
  475. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/LICENSE +0 -0
  476. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/WHEEL +0 -0
  477. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/entry_points.txt +0 -0
  478. {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/top_level.txt +0 -0
diffusers/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = "0.33.0"
1
+ __version__ = "0.34.0"
2
2
 
3
3
  from typing import TYPE_CHECKING
4
4
 
@@ -148,6 +148,7 @@ else:
148
148
  "AutoencoderKL",
149
149
  "AutoencoderKLAllegro",
150
150
  "AutoencoderKLCogVideoX",
151
+ "AutoencoderKLCosmos",
151
152
  "AutoencoderKLHunyuanVideo",
152
153
  "AutoencoderKLLTXVideo",
153
154
  "AutoencoderKLMagvit",
@@ -158,6 +159,7 @@ else:
158
159
  "AutoencoderTiny",
159
160
  "AutoModel",
160
161
  "CacheMixin",
162
+ "ChromaTransformer2DModel",
161
163
  "CogVideoXTransformer3DModel",
162
164
  "CogView3PlusTransformer2DModel",
163
165
  "CogView4Transformer2DModel",
@@ -166,14 +168,17 @@ else:
166
168
  "ControlNetModel",
167
169
  "ControlNetUnionModel",
168
170
  "ControlNetXSAdapter",
171
+ "CosmosTransformer3DModel",
169
172
  "DiTTransformer2DModel",
170
173
  "EasyAnimateTransformer3DModel",
171
174
  "FluxControlNetModel",
172
175
  "FluxMultiControlNetModel",
173
176
  "FluxTransformer2DModel",
177
+ "HiDreamImageTransformer2DModel",
174
178
  "HunyuanDiT2DControlNetModel",
175
179
  "HunyuanDiT2DModel",
176
180
  "HunyuanDiT2DMultiControlNetModel",
181
+ "HunyuanVideoFramepackTransformer3DModel",
177
182
  "HunyuanVideoTransformer3DModel",
178
183
  "I2VGenXLUNet",
179
184
  "Kandinsky3UNet",
@@ -189,6 +194,7 @@ else:
189
194
  "OmniGenTransformer2DModel",
190
195
  "PixArtTransformer2DModel",
191
196
  "PriorTransformer",
197
+ "SanaControlNetModel",
192
198
  "SanaTransformer2DModel",
193
199
  "SD3ControlNetModel",
194
200
  "SD3MultiControlNetModel",
@@ -210,6 +216,7 @@ else:
210
216
  "UVit2DModel",
211
217
  "VQModel",
212
218
  "WanTransformer3DModel",
219
+ "WanVACETransformer3DModel",
213
220
  ]
214
221
  )
215
222
  _import_structure["optimization"] = [
@@ -266,6 +273,7 @@ else:
266
273
  "EulerDiscreteScheduler",
267
274
  "FlowMatchEulerDiscreteScheduler",
268
275
  "FlowMatchHeunDiscreteScheduler",
276
+ "FlowMatchLCMScheduler",
269
277
  "HeunDiscreteScheduler",
270
278
  "IPNDMScheduler",
271
279
  "KarrasVeScheduler",
@@ -345,6 +353,8 @@ else:
345
353
  "AuraFlowPipeline",
346
354
  "BlipDiffusionControlNetPipeline",
347
355
  "BlipDiffusionPipeline",
356
+ "ChromaImg2ImgPipeline",
357
+ "ChromaPipeline",
348
358
  "CLIPImageProjection",
349
359
  "CogVideoXFunControlPipeline",
350
360
  "CogVideoXImageToVideoPipeline",
@@ -353,6 +363,11 @@ else:
353
363
  "CogView3PlusPipeline",
354
364
  "CogView4ControlPipeline",
355
365
  "CogView4Pipeline",
366
+ "ConsisIDPipeline",
367
+ "Cosmos2TextToImagePipeline",
368
+ "Cosmos2VideoToWorldPipeline",
369
+ "CosmosTextToWorldPipeline",
370
+ "CosmosVideoToWorldPipeline",
356
371
  "CycleDiffusionPipeline",
357
372
  "EasyAnimateControlPipeline",
358
373
  "EasyAnimateInpaintPipeline",
@@ -368,10 +383,12 @@ else:
368
383
  "FluxInpaintPipeline",
369
384
  "FluxPipeline",
370
385
  "FluxPriorReduxPipeline",
386
+ "HiDreamImagePipeline",
371
387
  "HunyuanDiTControlNetPipeline",
372
388
  "HunyuanDiTPAGPipeline",
373
389
  "HunyuanDiTPipeline",
374
390
  "HunyuanSkyreelsImageToVideoPipeline",
391
+ "HunyuanVideoFramepackPipeline",
375
392
  "HunyuanVideoImageToVideoPipeline",
376
393
  "HunyuanVideoPipeline",
377
394
  "I2VGenXLPipeline",
@@ -409,6 +426,7 @@ else:
409
426
  "LEditsPPPipelineStableDiffusionXL",
410
427
  "LTXConditionPipeline",
411
428
  "LTXImageToVideoPipeline",
429
+ "LTXLatentUpsamplePipeline",
412
430
  "LTXPipeline",
413
431
  "Lumina2Pipeline",
414
432
  "Lumina2Text2ImgPipeline",
@@ -426,8 +444,10 @@ else:
426
444
  "PixArtSigmaPAGPipeline",
427
445
  "PixArtSigmaPipeline",
428
446
  "ReduxImageEncoder",
447
+ "SanaControlNetPipeline",
429
448
  "SanaPAGPipeline",
430
449
  "SanaPipeline",
450
+ "SanaSprintImg2ImgPipeline",
431
451
  "SanaSprintPipeline",
432
452
  "SemanticStableDiffusionPipeline",
433
453
  "ShapEImg2ImgPipeline",
@@ -508,9 +528,12 @@ else:
508
528
  "VersatileDiffusionPipeline",
509
529
  "VersatileDiffusionTextToImagePipeline",
510
530
  "VideoToVideoSDPipeline",
531
+ "VisualClozeGenerationPipeline",
532
+ "VisualClozePipeline",
511
533
  "VQDiffusionPipeline",
512
534
  "WanImageToVideoPipeline",
513
535
  "WanPipeline",
536
+ "WanVACEPipeline",
514
537
  "WanVideoToVideoPipeline",
515
538
  "WuerstchenCombinedPipeline",
516
539
  "WuerstchenDecoderPipeline",
@@ -676,6 +699,7 @@ else:
676
699
 
677
700
  if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
678
701
  from .configuration_utils import ConfigMixin
702
+ from .quantizers import PipelineQuantizationConfig
679
703
 
680
704
  try:
681
705
  if not is_bitsandbytes_available():
@@ -738,6 +762,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
738
762
  AutoencoderKL,
739
763
  AutoencoderKLAllegro,
740
764
  AutoencoderKLCogVideoX,
765
+ AutoencoderKLCosmos,
741
766
  AutoencoderKLHunyuanVideo,
742
767
  AutoencoderKLLTXVideo,
743
768
  AutoencoderKLMagvit,
@@ -748,6 +773,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
748
773
  AutoencoderTiny,
749
774
  AutoModel,
750
775
  CacheMixin,
776
+ ChromaTransformer2DModel,
751
777
  CogVideoXTransformer3DModel,
752
778
  CogView3PlusTransformer2DModel,
753
779
  CogView4Transformer2DModel,
@@ -756,14 +782,17 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
756
782
  ControlNetModel,
757
783
  ControlNetUnionModel,
758
784
  ControlNetXSAdapter,
785
+ CosmosTransformer3DModel,
759
786
  DiTTransformer2DModel,
760
787
  EasyAnimateTransformer3DModel,
761
788
  FluxControlNetModel,
762
789
  FluxMultiControlNetModel,
763
790
  FluxTransformer2DModel,
791
+ HiDreamImageTransformer2DModel,
764
792
  HunyuanDiT2DControlNetModel,
765
793
  HunyuanDiT2DModel,
766
794
  HunyuanDiT2DMultiControlNetModel,
795
+ HunyuanVideoFramepackTransformer3DModel,
767
796
  HunyuanVideoTransformer3DModel,
768
797
  I2VGenXLUNet,
769
798
  Kandinsky3UNet,
@@ -779,6 +808,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
779
808
  OmniGenTransformer2DModel,
780
809
  PixArtTransformer2DModel,
781
810
  PriorTransformer,
811
+ SanaControlNetModel,
782
812
  SanaTransformer2DModel,
783
813
  SD3ControlNetModel,
784
814
  SD3MultiControlNetModel,
@@ -799,6 +829,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
799
829
  UVit2DModel,
800
830
  VQModel,
801
831
  WanTransformer3DModel,
832
+ WanVACETransformer3DModel,
802
833
  )
803
834
  from .optimization import (
804
835
  get_constant_schedule,
@@ -854,6 +885,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
854
885
  EulerDiscreteScheduler,
855
886
  FlowMatchEulerDiscreteScheduler,
856
887
  FlowMatchHeunDiscreteScheduler,
888
+ FlowMatchLCMScheduler,
857
889
  HeunDiscreteScheduler,
858
890
  IPNDMScheduler,
859
891
  KarrasVeScheduler,
@@ -914,6 +946,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
914
946
  AudioLDM2UNet2DConditionModel,
915
947
  AudioLDMPipeline,
916
948
  AuraFlowPipeline,
949
+ ChromaImg2ImgPipeline,
950
+ ChromaPipeline,
917
951
  CLIPImageProjection,
918
952
  CogVideoXFunControlPipeline,
919
953
  CogVideoXImageToVideoPipeline,
@@ -922,6 +956,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
922
956
  CogView3PlusPipeline,
923
957
  CogView4ControlPipeline,
924
958
  CogView4Pipeline,
959
+ ConsisIDPipeline,
960
+ Cosmos2TextToImagePipeline,
961
+ Cosmos2VideoToWorldPipeline,
962
+ CosmosTextToWorldPipeline,
963
+ CosmosVideoToWorldPipeline,
925
964
  CycleDiffusionPipeline,
926
965
  EasyAnimateControlPipeline,
927
966
  EasyAnimateInpaintPipeline,
@@ -937,10 +976,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
937
976
  FluxInpaintPipeline,
938
977
  FluxPipeline,
939
978
  FluxPriorReduxPipeline,
979
+ HiDreamImagePipeline,
940
980
  HunyuanDiTControlNetPipeline,
941
981
  HunyuanDiTPAGPipeline,
942
982
  HunyuanDiTPipeline,
943
983
  HunyuanSkyreelsImageToVideoPipeline,
984
+ HunyuanVideoFramepackPipeline,
944
985
  HunyuanVideoImageToVideoPipeline,
945
986
  HunyuanVideoPipeline,
946
987
  I2VGenXLPipeline,
@@ -978,6 +1019,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
978
1019
  LEditsPPPipelineStableDiffusionXL,
979
1020
  LTXConditionPipeline,
980
1021
  LTXImageToVideoPipeline,
1022
+ LTXLatentUpsamplePipeline,
981
1023
  LTXPipeline,
982
1024
  Lumina2Pipeline,
983
1025
  Lumina2Text2ImgPipeline,
@@ -995,8 +1037,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
995
1037
  PixArtSigmaPAGPipeline,
996
1038
  PixArtSigmaPipeline,
997
1039
  ReduxImageEncoder,
1040
+ SanaControlNetPipeline,
998
1041
  SanaPAGPipeline,
999
1042
  SanaPipeline,
1043
+ SanaSprintImg2ImgPipeline,
1000
1044
  SanaSprintPipeline,
1001
1045
  SemanticStableDiffusionPipeline,
1002
1046
  ShapEImg2ImgPipeline,
@@ -1076,9 +1120,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
1076
1120
  VersatileDiffusionPipeline,
1077
1121
  VersatileDiffusionTextToImagePipeline,
1078
1122
  VideoToVideoSDPipeline,
1123
+ VisualClozeGenerationPipeline,
1124
+ VisualClozePipeline,
1079
1125
  VQDiffusionPipeline,
1080
1126
  WanImageToVideoPipeline,
1081
1127
  WanPipeline,
1128
+ WanVACEPipeline,
1082
1129
  WanVideoToVideoPipeline,
1083
1130
  WuerstchenCombinedPipeline,
1084
1131
  WuerstchenDecoderPipeline,
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,5 +1,5 @@
1
1
  #!/usr/bin/env python
2
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
diffusers/commands/env.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@ deps = {
23
23
  "librosa": "librosa",
24
24
  "numpy": "numpy",
25
25
  "parameterized": "parameterized",
26
- "peft": "peft>=0.6.0",
26
+ "peft": "peft>=0.15.0",
27
27
  "protobuf": "protobuf>=3.20.3,<4",
28
28
  "pytest": "pytest",
29
29
  "pytest-timeout": "pytest-timeout",
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -146,7 +146,7 @@ class FasterCacheConfig:
146
146
  alpha_low_frequency: float = 1.1
147
147
  alpha_high_frequency: float = 1.1
148
148
 
149
- # n as described in CFG-Cache explanation in the paper - dependant on the model
149
+ # n as described in CFG-Cache explanation in the paper - dependent on the model
150
150
  unconditional_batch_skip_range: int = 5
151
151
  unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641)
152
152
 
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -12,9 +12,11 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ import os
15
16
  from contextlib import contextmanager, nullcontext
16
- from typing import Dict, List, Optional, Set, Tuple
17
+ from typing import Dict, List, Optional, Set, Tuple, Union
17
18
 
19
+ import safetensors.torch
18
20
  import torch
19
21
 
20
22
  from ..utils import get_logger, is_accelerate_available
@@ -55,10 +57,11 @@ class ModuleGroup:
55
57
  parameters: Optional[List[torch.nn.Parameter]] = None,
56
58
  buffers: Optional[List[torch.Tensor]] = None,
57
59
  non_blocking: bool = False,
58
- stream: Optional[torch.cuda.Stream] = None,
60
+ stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
59
61
  record_stream: Optional[bool] = False,
60
- low_cpu_mem_usage=False,
62
+ low_cpu_mem_usage: bool = False,
61
63
  onload_self: bool = True,
64
+ offload_to_disk_path: Optional[str] = None,
62
65
  ) -> None:
63
66
  self.modules = modules
64
67
  self.offload_device = offload_device
@@ -72,10 +75,26 @@ class ModuleGroup:
72
75
  self.record_stream = record_stream
73
76
  self.onload_self = onload_self
74
77
  self.low_cpu_mem_usage = low_cpu_mem_usage
75
- self.cpu_param_dict = self._init_cpu_param_dict()
76
78
 
77
- if self.stream is None and self.record_stream:
78
- raise ValueError("`record_stream` cannot be True when `stream` is None.")
79
+ self.offload_to_disk_path = offload_to_disk_path
80
+ self._is_offloaded_to_disk = False
81
+
82
+ if self.offload_to_disk_path:
83
+ self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
84
+
85
+ all_tensors = []
86
+ for module in self.modules:
87
+ all_tensors.extend(list(module.parameters()))
88
+ all_tensors.extend(list(module.buffers()))
89
+ all_tensors.extend(self.parameters)
90
+ all_tensors.extend(self.buffers)
91
+ all_tensors = list(dict.fromkeys(all_tensors)) # Remove duplicates
92
+
93
+ self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
94
+ self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
95
+ self.cpu_param_dict = {}
96
+ else:
97
+ self.cpu_param_dict = self._init_cpu_param_dict()
79
98
 
80
99
  def _init_cpu_param_dict(self):
81
100
  cpu_param_dict = {}
@@ -113,10 +132,40 @@ class ModuleGroup:
113
132
  finally:
114
133
  pinned_dict = None
115
134
 
135
+ @torch.compiler.disable()
116
136
  def onload_(self):
117
137
  r"""Onloads the group of modules to the onload_device."""
118
- context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
119
- current_stream = torch.cuda.current_stream() if self.record_stream else None
138
+ torch_accelerator_module = (
139
+ getattr(torch, torch.accelerator.current_accelerator().type)
140
+ if hasattr(torch, "accelerator")
141
+ else torch.cuda
142
+ )
143
+ context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
144
+ current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
145
+
146
+ if self.offload_to_disk_path:
147
+ if self.stream is not None:
148
+ # Wait for previous Host->Device transfer to complete
149
+ self.stream.synchronize()
150
+
151
+ with context:
152
+ if self.stream is not None:
153
+ # Load to CPU, pin, and async copy to device for overlapping transfer and compute
154
+ loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
155
+ for key, tensor_obj in self.key_to_tensor.items():
156
+ pinned_tensor = loaded_cpu_tensors[key].pin_memory()
157
+ tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
158
+ if self.record_stream:
159
+ tensor_obj.data.record_stream(current_stream)
160
+ else:
161
+ # Load directly to the target device (synchronous)
162
+ onload_device = (
163
+ self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
164
+ )
165
+ loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
166
+ for key, tensor_obj in self.key_to_tensor.items():
167
+ tensor_obj.data = loaded_tensors[key]
168
+ return
120
169
 
121
170
  if self.stream is not None:
122
171
  # Wait for previous Host->Device transfer to complete
@@ -160,11 +209,38 @@ class ModuleGroup:
160
209
  if self.record_stream:
161
210
  buffer.data.record_stream(current_stream)
162
211
 
212
+ @torch.compiler.disable()
163
213
  def offload_(self):
164
214
  r"""Offloads the group of modules to the offload_device."""
215
+ if self.offload_to_disk_path:
216
+ # TODO: we can potentially optimize this code path by checking if _all_ the desired
217
+ # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
218
+ # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
219
+ # we perform a write.
220
+ # Check if the file has been saved in this session or if it already exists on disk.
221
+ if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
222
+ os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
223
+ tensors_to_save = {
224
+ key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
225
+ }
226
+ safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
227
+
228
+ # The group is now considered offloaded to disk for the rest of the session.
229
+ self._is_offloaded_to_disk = True
230
+
231
+ # We do this to free up the RAM which is still holding the tensor data.
232
+ for tensor_obj in self.tensor_to_key.keys():
233
+ tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
234
+ return
235
+
236
+ torch_accelerator_module = (
237
+ getattr(torch, torch.accelerator.current_accelerator().type)
238
+ if hasattr(torch, "accelerator")
239
+ else torch.cuda
240
+ )
165
241
  if self.stream is not None:
166
242
  if not self.record_stream:
167
- torch.cuda.current_stream().synchronize()
243
+ torch_accelerator_module.current_stream().synchronize()
168
244
  for group_module in self.modules:
169
245
  for param in group_module.parameters():
170
246
  param.data = self.cpu_param_dict[param]
@@ -192,11 +268,7 @@ class GroupOffloadingHook(ModelHook):
192
268
 
193
269
  _is_stateful = False
194
270
 
195
- def __init__(
196
- self,
197
- group: ModuleGroup,
198
- next_group: Optional[ModuleGroup] = None,
199
- ) -> None:
271
+ def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
200
272
  self.group = group
201
273
  self.next_group = next_group
202
274
 
@@ -232,7 +304,7 @@ class GroupOffloadingHook(ModelHook):
232
304
 
233
305
  class LazyPrefetchGroupOffloadingHook(ModelHook):
234
306
  r"""
235
- A hook, used in conjuction with GroupOffloadingHook, that applies lazy prefetching to groups of torch.nn.Module.
307
+ A hook, used in conjunction with GroupOffloadingHook, that applies lazy prefetching to groups of torch.nn.Module.
236
308
  This hook is used to determine the order in which the layers are executed during the forward pass. Once the layer
237
309
  invocation order is known, assignments of the next_group attribute for prefetching can be made, which allows
238
310
  prefetching groups in the correct order.
@@ -350,6 +422,7 @@ def apply_group_offloading(
350
422
  use_stream: bool = False,
351
423
  record_stream: bool = False,
352
424
  low_cpu_mem_usage: bool = False,
425
+ offload_to_disk_path: Optional[str] = None,
353
426
  ) -> None:
354
427
  r"""
355
428
  Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -388,6 +461,9 @@ def apply_group_offloading(
388
461
  offload_type (`str`, defaults to "block_level"):
389
462
  The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
390
463
  "block_level".
464
+ offload_to_disk_path (`str`, *optional*, defaults to `None`):
465
+ The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
466
+ RAM environment settings where a reasonable speed-memory trade-off is desired.
391
467
  num_blocks_per_group (`int`, *optional*):
392
468
  The number of blocks per group when using offload_type="block_level". This is required when using
393
469
  offload_type="block_level".
@@ -429,8 +505,13 @@ def apply_group_offloading(
429
505
  if use_stream:
430
506
  if torch.cuda.is_available():
431
507
  stream = torch.cuda.Stream()
508
+ elif hasattr(torch, "xpu") and torch.xpu.is_available():
509
+ stream = torch.Stream()
432
510
  else:
433
- raise ValueError("Using streams for data transfer requires a CUDA device.")
511
+ raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")
512
+
513
+ if not use_stream and record_stream:
514
+ raise ValueError("`record_stream` cannot be True when `use_stream=False`.")
434
515
 
435
516
  _raise_error_if_accelerate_model_or_sequential_hook_present(module)
436
517
 
@@ -443,6 +524,7 @@ def apply_group_offloading(
443
524
  num_blocks_per_group=num_blocks_per_group,
444
525
  offload_device=offload_device,
445
526
  onload_device=onload_device,
527
+ offload_to_disk_path=offload_to_disk_path,
446
528
  non_blocking=non_blocking,
447
529
  stream=stream,
448
530
  record_stream=record_stream,
@@ -453,6 +535,7 @@ def apply_group_offloading(
453
535
  module=module,
454
536
  offload_device=offload_device,
455
537
  onload_device=onload_device,
538
+ offload_to_disk_path=offload_to_disk_path,
456
539
  non_blocking=non_blocking,
457
540
  stream=stream,
458
541
  record_stream=record_stream,
@@ -468,9 +551,10 @@ def _apply_group_offloading_block_level(
468
551
  offload_device: torch.device,
469
552
  onload_device: torch.device,
470
553
  non_blocking: bool,
471
- stream: Optional[torch.cuda.Stream] = None,
554
+ stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
472
555
  record_stream: Optional[bool] = False,
473
556
  low_cpu_mem_usage: bool = False,
557
+ offload_to_disk_path: Optional[str] = None,
474
558
  ) -> None:
475
559
  r"""
476
560
  This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -481,12 +565,15 @@ def _apply_group_offloading_block_level(
481
565
  The module to which group offloading is applied.
482
566
  offload_device (`torch.device`):
483
567
  The device to which the group of modules are offloaded. This should typically be the CPU.
568
+ offload_to_disk_path (`str`, *optional*, defaults to `None`):
569
+ The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
570
+ RAM environment settings where a reasonable speed-memory trade-off is desired.
484
571
  onload_device (`torch.device`):
485
572
  The device to which the group of modules are onloaded.
486
573
  non_blocking (`bool`):
487
574
  If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
488
575
  and data transfer.
489
- stream (`torch.cuda.Stream`, *optional*):
576
+ stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
490
577
  If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
491
578
  for overlapping computation and data transfer.
492
579
  record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
@@ -498,6 +585,11 @@ def _apply_group_offloading_block_level(
498
585
  option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
499
586
  the CPU memory is a bottleneck but may counteract the benefits of using streams.
500
587
  """
588
+ if stream is not None and num_blocks_per_group != 1:
589
+ logger.warning(
590
+ f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}. Setting it to 1."
591
+ )
592
+ num_blocks_per_group = 1
501
593
 
502
594
  # Create module groups for ModuleList and Sequential blocks
503
595
  modules_with_group_offloading = set()
@@ -515,13 +607,14 @@ def _apply_group_offloading_block_level(
515
607
  modules=current_modules,
516
608
  offload_device=offload_device,
517
609
  onload_device=onload_device,
610
+ offload_to_disk_path=offload_to_disk_path,
518
611
  offload_leader=current_modules[-1],
519
612
  onload_leader=current_modules[0],
520
613
  non_blocking=non_blocking,
521
614
  stream=stream,
522
615
  record_stream=record_stream,
523
616
  low_cpu_mem_usage=low_cpu_mem_usage,
524
- onload_self=stream is None,
617
+ onload_self=True,
525
618
  )
526
619
  matched_module_groups.append(group)
527
620
  for j in range(i, i + len(current_modules)):
@@ -529,12 +622,8 @@ def _apply_group_offloading_block_level(
529
622
 
530
623
  # Apply group offloading hooks to the module groups
531
624
  for i, group in enumerate(matched_module_groups):
532
- next_group = (
533
- matched_module_groups[i + 1] if i + 1 < len(matched_module_groups) and stream is not None else None
534
- )
535
-
536
625
  for group_module in group.modules:
537
- _apply_group_offloading_hook(group_module, group, next_group)
626
+ _apply_group_offloading_hook(group_module, group, None)
538
627
 
539
628
  # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
540
629
  # when the forward pass of this module is called. This is because the top-level module is not
@@ -551,6 +640,7 @@ def _apply_group_offloading_block_level(
551
640
  modules=unmatched_modules,
552
641
  offload_device=offload_device,
553
642
  onload_device=onload_device,
643
+ offload_to_disk_path=offload_to_disk_path,
554
644
  offload_leader=module,
555
645
  onload_leader=module,
556
646
  parameters=parameters,
@@ -560,8 +650,10 @@ def _apply_group_offloading_block_level(
560
650
  record_stream=False,
561
651
  onload_self=True,
562
652
  )
563
- next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
564
- _apply_group_offloading_hook(module, unmatched_group, next_group)
653
+ if stream is None:
654
+ _apply_group_offloading_hook(module, unmatched_group, None)
655
+ else:
656
+ _apply_lazy_group_offloading_hook(module, unmatched_group, None)
565
657
 
566
658
 
567
659
  def _apply_group_offloading_leaf_level(
@@ -569,9 +661,10 @@ def _apply_group_offloading_leaf_level(
569
661
  offload_device: torch.device,
570
662
  onload_device: torch.device,
571
663
  non_blocking: bool,
572
- stream: Optional[torch.cuda.Stream] = None,
664
+ stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
573
665
  record_stream: Optional[bool] = False,
574
666
  low_cpu_mem_usage: bool = False,
667
+ offload_to_disk_path: Optional[str] = None,
575
668
  ) -> None:
576
669
  r"""
577
670
  This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -586,10 +679,13 @@ def _apply_group_offloading_leaf_level(
586
679
  The device to which the group of modules are offloaded. This should typically be the CPU.
587
680
  onload_device (`torch.device`):
588
681
  The device to which the group of modules are onloaded.
682
+ offload_to_disk_path (`str`, *optional*, defaults to `None`):
683
+ The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
684
+ RAM environment settings where a reasonable speed-memory trade-off is desired.
589
685
  non_blocking (`bool`):
590
686
  If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
591
687
  and data transfer.
592
- stream (`torch.cuda.Stream`, *optional*):
688
+ stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
593
689
  If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
594
690
  for overlapping computation and data transfer.
595
691
  record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
@@ -611,6 +707,7 @@ def _apply_group_offloading_leaf_level(
611
707
  modules=[submodule],
612
708
  offload_device=offload_device,
613
709
  onload_device=onload_device,
710
+ offload_to_disk_path=offload_to_disk_path,
614
711
  offload_leader=submodule,
615
712
  onload_leader=submodule,
616
713
  non_blocking=non_blocking,
@@ -657,6 +754,7 @@ def _apply_group_offloading_leaf_level(
657
754
  onload_device=onload_device,
658
755
  offload_leader=parent_module,
659
756
  onload_leader=parent_module,
757
+ offload_to_disk_path=offload_to_disk_path,
660
758
  parameters=parameters,
661
759
  buffers=buffers,
662
760
  non_blocking=non_blocking,
@@ -675,6 +773,7 @@ def _apply_group_offloading_leaf_level(
675
773
  modules=[],
676
774
  offload_device=offload_device,
677
775
  onload_device=onload_device,
776
+ offload_to_disk_path=offload_to_disk_path,
678
777
  offload_leader=module,
679
778
  onload_leader=module,
680
779
  parameters=None,