diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (551)
  1. diffusers/__init__.py +145 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/__init__.py +1 -1
  4. diffusers/commands/custom_blocks.py +134 -0
  5. diffusers/commands/diffusers_cli.py +3 -1
  6. diffusers/commands/env.py +1 -1
  7. diffusers/commands/fp16_safetensors.py +2 -2
  8. diffusers/configuration_utils.py +11 -2
  9. diffusers/dependency_versions_check.py +1 -1
  10. diffusers/dependency_versions_table.py +3 -3
  11. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  12. diffusers/guiders/__init__.py +41 -0
  13. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  14. diffusers/guiders/auto_guidance.py +190 -0
  15. diffusers/guiders/classifier_free_guidance.py +141 -0
  16. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  17. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  18. diffusers/guiders/guider_utils.py +309 -0
  19. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  20. diffusers/guiders/skip_layer_guidance.py +262 -0
  21. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  22. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  23. diffusers/hooks/__init__.py +17 -0
  24. diffusers/hooks/_common.py +56 -0
  25. diffusers/hooks/_helpers.py +293 -0
  26. diffusers/hooks/faster_cache.py +9 -8
  27. diffusers/hooks/first_block_cache.py +259 -0
  28. diffusers/hooks/group_offloading.py +332 -227
  29. diffusers/hooks/hooks.py +58 -3
  30. diffusers/hooks/layer_skip.py +263 -0
  31. diffusers/hooks/layerwise_casting.py +5 -10
  32. diffusers/hooks/pyramid_attention_broadcast.py +15 -12
  33. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  34. diffusers/hooks/utils.py +43 -0
  35. diffusers/image_processor.py +7 -2
  36. diffusers/loaders/__init__.py +10 -0
  37. diffusers/loaders/ip_adapter.py +260 -18
  38. diffusers/loaders/lora_base.py +261 -127
  39. diffusers/loaders/lora_conversion_utils.py +657 -35
  40. diffusers/loaders/lora_pipeline.py +2778 -1246
  41. diffusers/loaders/peft.py +78 -112
  42. diffusers/loaders/single_file.py +2 -2
  43. diffusers/loaders/single_file_model.py +64 -15
  44. diffusers/loaders/single_file_utils.py +395 -7
  45. diffusers/loaders/textual_inversion.py +3 -2
  46. diffusers/loaders/transformer_flux.py +10 -11
  47. diffusers/loaders/transformer_sd3.py +8 -3
  48. diffusers/loaders/unet.py +24 -21
  49. diffusers/loaders/unet_loader_utils.py +6 -3
  50. diffusers/loaders/utils.py +1 -1
  51. diffusers/models/__init__.py +23 -1
  52. diffusers/models/activations.py +5 -5
  53. diffusers/models/adapter.py +2 -3
  54. diffusers/models/attention.py +488 -7
  55. diffusers/models/attention_dispatch.py +1218 -0
  56. diffusers/models/attention_flax.py +10 -10
  57. diffusers/models/attention_processor.py +113 -667
  58. diffusers/models/auto_model.py +49 -12
  59. diffusers/models/autoencoders/__init__.py +2 -0
  60. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  61. diffusers/models/autoencoders/autoencoder_dc.py +17 -4
  62. diffusers/models/autoencoders/autoencoder_kl.py +5 -5
  63. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  64. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  65. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
  66. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  67. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  68. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  69. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  70. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  71. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  72. diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
  73. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  74. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  75. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  76. diffusers/models/autoencoders/vae.py +13 -2
  77. diffusers/models/autoencoders/vq_model.py +2 -2
  78. diffusers/models/cache_utils.py +32 -10
  79. diffusers/models/controlnet.py +1 -1
  80. diffusers/models/controlnet_flux.py +1 -1
  81. diffusers/models/controlnet_sd3.py +1 -1
  82. diffusers/models/controlnet_sparsectrl.py +1 -1
  83. diffusers/models/controlnets/__init__.py +1 -0
  84. diffusers/models/controlnets/controlnet.py +3 -3
  85. diffusers/models/controlnets/controlnet_flax.py +1 -1
  86. diffusers/models/controlnets/controlnet_flux.py +21 -20
  87. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  88. diffusers/models/controlnets/controlnet_sana.py +290 -0
  89. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  90. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  91. diffusers/models/controlnets/controlnet_union.py +5 -5
  92. diffusers/models/controlnets/controlnet_xs.py +7 -7
  93. diffusers/models/controlnets/multicontrolnet.py +4 -5
  94. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  95. diffusers/models/downsampling.py +2 -2
  96. diffusers/models/embeddings.py +36 -46
  97. diffusers/models/embeddings_flax.py +2 -2
  98. diffusers/models/lora.py +3 -3
  99. diffusers/models/model_loading_utils.py +233 -1
  100. diffusers/models/modeling_flax_utils.py +1 -2
  101. diffusers/models/modeling_utils.py +203 -108
  102. diffusers/models/normalization.py +4 -4
  103. diffusers/models/resnet.py +2 -2
  104. diffusers/models/resnet_flax.py +1 -1
  105. diffusers/models/transformers/__init__.py +7 -0
  106. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  107. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  108. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  109. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  110. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  111. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  112. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  113. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  114. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  115. diffusers/models/transformers/prior_transformer.py +1 -1
  116. diffusers/models/transformers/sana_transformer.py +8 -3
  117. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  118. diffusers/models/transformers/t5_film_transformer.py +3 -3
  119. diffusers/models/transformers/transformer_2d.py +1 -1
  120. diffusers/models/transformers/transformer_allegro.py +1 -1
  121. diffusers/models/transformers/transformer_chroma.py +641 -0
  122. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  123. diffusers/models/transformers/transformer_cogview4.py +353 -27
  124. diffusers/models/transformers/transformer_cosmos.py +586 -0
  125. diffusers/models/transformers/transformer_flux.py +376 -138
  126. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  127. diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
  128. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  129. diffusers/models/transformers/transformer_ltx.py +105 -24
  130. diffusers/models/transformers/transformer_lumina2.py +1 -1
  131. diffusers/models/transformers/transformer_mochi.py +1 -1
  132. diffusers/models/transformers/transformer_omnigen.py +2 -2
  133. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  134. diffusers/models/transformers/transformer_sd3.py +7 -7
  135. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  136. diffusers/models/transformers/transformer_temporal.py +1 -1
  137. diffusers/models/transformers/transformer_wan.py +316 -87
  138. diffusers/models/transformers/transformer_wan_vace.py +387 -0
  139. diffusers/models/unets/unet_1d.py +1 -1
  140. diffusers/models/unets/unet_1d_blocks.py +1 -1
  141. diffusers/models/unets/unet_2d.py +1 -1
  142. diffusers/models/unets/unet_2d_blocks.py +1 -1
  143. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  144. diffusers/models/unets/unet_2d_condition.py +4 -3
  145. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  146. diffusers/models/unets/unet_3d_blocks.py +1 -1
  147. diffusers/models/unets/unet_3d_condition.py +3 -3
  148. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  149. diffusers/models/unets/unet_kandinsky3.py +1 -1
  150. diffusers/models/unets/unet_motion_model.py +2 -2
  151. diffusers/models/unets/unet_stable_cascade.py +1 -1
  152. diffusers/models/upsampling.py +2 -2
  153. diffusers/models/vae_flax.py +2 -2
  154. diffusers/models/vq_model.py +1 -1
  155. diffusers/modular_pipelines/__init__.py +83 -0
  156. diffusers/modular_pipelines/components_manager.py +1068 -0
  157. diffusers/modular_pipelines/flux/__init__.py +66 -0
  158. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  159. diffusers/modular_pipelines/flux/decoders.py +109 -0
  160. diffusers/modular_pipelines/flux/denoise.py +227 -0
  161. diffusers/modular_pipelines/flux/encoders.py +412 -0
  162. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  163. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  164. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  165. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  166. diffusers/modular_pipelines/node_utils.py +665 -0
  167. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  168. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  169. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  170. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  171. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  172. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  173. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  174. diffusers/modular_pipelines/wan/__init__.py +66 -0
  175. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  176. diffusers/modular_pipelines/wan/decoders.py +105 -0
  177. diffusers/modular_pipelines/wan/denoise.py +261 -0
  178. diffusers/modular_pipelines/wan/encoders.py +242 -0
  179. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  180. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  181. diffusers/pipelines/__init__.py +68 -6
  182. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  183. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  184. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  185. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  186. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  187. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  188. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  189. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  190. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  191. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  192. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  193. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  194. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
  195. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  196. diffusers/pipelines/auto_pipeline.py +23 -20
  197. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  198. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  199. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  200. diffusers/pipelines/chroma/__init__.py +49 -0
  201. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  202. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  203. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  204. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
  205. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
  206. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
  207. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
  208. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  209. diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
  210. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  211. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  212. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  213. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  214. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  215. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
  216. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  217. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  218. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  219. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  220. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  221. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
  222. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
  223. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
  224. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  225. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  226. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  227. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  228. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  229. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  230. diffusers/pipelines/cosmos/__init__.py +54 -0
  231. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  232. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  233. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  234. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  235. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  236. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  237. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  238. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  239. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  240. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  241. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  242. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  243. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  244. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  245. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  246. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  247. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  248. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  249. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  250. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  251. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  252. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  253. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  254. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  255. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  256. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  257. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
  258. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  259. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  260. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  261. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  262. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  263. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  264. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  265. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  266. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  267. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  268. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  269. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  270. diffusers/pipelines/dit/pipeline_dit.py +4 -2
  271. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  272. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  273. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  274. diffusers/pipelines/flux/__init__.py +4 -0
  275. diffusers/pipelines/flux/modeling_flux.py +1 -1
  276. diffusers/pipelines/flux/pipeline_flux.py +37 -36
  277. diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
  278. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
  279. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
  280. diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
  281. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
  282. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
  283. diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
  284. diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
  285. diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
  286. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  287. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  288. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
  289. diffusers/pipelines/flux/pipeline_output.py +6 -4
  290. diffusers/pipelines/free_init_utils.py +2 -2
  291. diffusers/pipelines/free_noise_utils.py +3 -3
  292. diffusers/pipelines/hidream_image/__init__.py +47 -0
  293. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  294. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  295. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  296. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  297. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
  298. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  299. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  300. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  301. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  302. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  303. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  304. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  305. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  306. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  307. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  308. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  309. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  310. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  311. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  312. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  313. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  314. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  315. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  316. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  317. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  318. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  319. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  320. diffusers/pipelines/kolors/text_encoder.py +3 -3
  321. diffusers/pipelines/kolors/tokenizer.py +1 -1
  322. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  323. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  324. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  325. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  326. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  327. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  328. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  329. diffusers/pipelines/ltx/__init__.py +4 -0
  330. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  331. diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
  332. diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
  333. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
  334. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  335. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  336. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  337. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  338. diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
  339. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  340. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  341. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  342. diffusers/pipelines/onnx_utils.py +15 -2
  343. diffusers/pipelines/pag/pag_utils.py +2 -2
  344. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  345. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  346. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  347. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  348. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  349. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  350. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  351. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  352. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  353. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  354. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  355. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  356. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  357. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  358. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  359. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  360. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  361. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  362. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  363. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  364. diffusers/pipelines/pipeline_flax_utils.py +5 -6
  365. diffusers/pipelines/pipeline_loading_utils.py +113 -15
  366. diffusers/pipelines/pipeline_utils.py +127 -48
  367. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
  368. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
  369. diffusers/pipelines/qwenimage/__init__.py +55 -0
  370. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  371. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  372. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  373. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  374. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  375. diffusers/pipelines/sana/__init__.py +4 -0
  376. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  377. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  378. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  379. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  380. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  381. diffusers/pipelines/shap_e/camera.py +1 -1
  382. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  383. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  384. diffusers/pipelines/shap_e/renderer.py +3 -3
  385. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  386. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  387. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  388. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  389. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  390. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  391. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  392. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  393. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  394. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  395. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  396. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  397. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  398. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  399. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  400. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  401. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  402. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  403. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
  404. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  405. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
  406. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
  407. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
  408. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  409. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  410. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  411. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  412. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  413. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  414. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  415. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  416. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  417. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  418. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  419. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  420. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
  421. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  422. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  423. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  424. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  425. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  426. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  427. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  428. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  429. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  430. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  431. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  432. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  433. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  434. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  435. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  436. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  437. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  438. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  439. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  440. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  441. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  442. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  443. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  444. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  445. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  446. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  447. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  448. diffusers/pipelines/unclip/text_proj.py +2 -2
  449. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  450. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  451. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  452. diffusers/pipelines/visualcloze/__init__.py +52 -0
  453. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  454. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  455. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  456. diffusers/pipelines/wan/__init__.py +2 -0
  457. diffusers/pipelines/wan/pipeline_wan.py +91 -30
  458. diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
  459. diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
  460. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  461. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  462. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  463. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  464. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  465. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  466. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  467. diffusers/quantizers/__init__.py +3 -1
  468. diffusers/quantizers/base.py +17 -1
  469. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  470. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  471. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  472. diffusers/quantizers/gguf/utils.py +108 -16
  473. diffusers/quantizers/pipe_quant_config.py +202 -0
  474. diffusers/quantizers/quantization_config.py +18 -16
  475. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  476. diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
  477. diffusers/schedulers/__init__.py +3 -1
  478. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  479. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  480. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  481. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  482. diffusers/schedulers/scheduling_ddim.py +8 -8
  483. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  484. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  485. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  486. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  487. diffusers/schedulers/scheduling_ddpm.py +9 -9
  488. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  489. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  490. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  491. diffusers/schedulers/scheduling_deis_multistep.py +16 -9
  492. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  493. diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
  494. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  495. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  496. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  497. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
  498. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  499. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  500. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  501. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  502. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  503. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  504. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  505. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  506. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  507. diffusers/schedulers/scheduling_ipndm.py +2 -2
  508. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  509. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  510. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  511. diffusers/schedulers/scheduling_lcm.py +3 -3
  512. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  513. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  514. diffusers/schedulers/scheduling_pndm.py +4 -4
  515. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  516. diffusers/schedulers/scheduling_repaint.py +9 -9
  517. diffusers/schedulers/scheduling_sasolver.py +15 -15
  518. diffusers/schedulers/scheduling_scm.py +1 -2
  519. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  520. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  521. diffusers/schedulers/scheduling_tcd.py +3 -3
  522. diffusers/schedulers/scheduling_unclip.py +5 -5
  523. diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
  524. diffusers/schedulers/scheduling_utils.py +3 -3
  525. diffusers/schedulers/scheduling_utils_flax.py +2 -2
  526. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  527. diffusers/training_utils.py +91 -5
  528. diffusers/utils/__init__.py +15 -0
  529. diffusers/utils/accelerate_utils.py +1 -1
  530. diffusers/utils/constants.py +4 -0
  531. diffusers/utils/doc_utils.py +1 -1
  532. diffusers/utils/dummy_pt_objects.py +432 -0
  533. diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
  534. diffusers/utils/dynamic_modules_utils.py +85 -8
  535. diffusers/utils/export_utils.py +1 -1
  536. diffusers/utils/hub_utils.py +33 -17
  537. diffusers/utils/import_utils.py +151 -18
  538. diffusers/utils/logging.py +1 -1
  539. diffusers/utils/outputs.py +2 -1
  540. diffusers/utils/peft_utils.py +96 -10
  541. diffusers/utils/state_dict_utils.py +20 -3
  542. diffusers/utils/testing_utils.py +195 -17
  543. diffusers/utils/torch_utils.py +43 -5
  544. diffusers/video_processor.py +2 -2
  545. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
  546. diffusers-0.35.0.dist-info/RECORD +703 -0
  547. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
  548. diffusers-0.33.1.dist-info/RECORD +0 -608
  549. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  550. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  551. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
diffusers/hooks/group_offloading.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,12 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import hashlib
+import os
 from contextlib import contextmanager, nullcontext
-from typing import Dict, List, Optional, Set, Tuple
+from dataclasses import dataclass
+from enum import Enum
+from typing import Dict, List, Optional, Set, Tuple, Union
 
+import safetensors.torch
 import torch
 
 from ..utils import get_logger, is_accelerate_available
+from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
 from .hooks import HookRegistry, ModelHook
 
 
@@ -33,17 +39,28 @@ logger = get_logger(__name__) # pylint: disable=invalid-name
 _GROUP_OFFLOADING = "group_offloading"
 _LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
 _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
-
-_SUPPORTED_PYTORCH_LAYERS = (
-    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
-    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
-    torch.nn.Linear,
-    # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
-    # because of double invocation of the same norm layer in CogVideoXLayerNorm
-)
+_GROUP_ID_LAZY_LEAF = "lazy_leafs"
 # fmt: on
 
 
+class GroupOffloadingType(str, Enum):
+    BLOCK_LEVEL = "block_level"
+    LEAF_LEVEL = "leaf_level"
+
+
+@dataclass
+class GroupOffloadingConfig:
+    onload_device: torch.device
+    offload_device: torch.device
+    offload_type: GroupOffloadingType
+    non_blocking: bool
+    record_stream: bool
+    low_cpu_mem_usage: bool
+    num_blocks_per_group: Optional[int] = None
+    offload_to_disk_path: Optional[str] = None
+    stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
+
+
 class ModuleGroup:
     def __init__(
         self,
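
The hunk above replaces the module-level `_SUPPORTED_PYTORCH_LAYERS` tuple (now imported from `._common`) with a `str`-backed `GroupOffloadingType` enum and a `GroupOffloadingConfig` dataclass that carries every offloading option. A standalone sketch, independent of diffusers, of why the `str`/`Enum` mixin accepts both plain strings and enum members, which `apply_group_offloading` relies on later in this diff:

    from enum import Enum

    class GroupOffloadingType(str, Enum):
        BLOCK_LEVEL = "block_level"
        LEAF_LEVEL = "leaf_level"

    # The coercion used later in apply_group_offloading accepts either form:
    assert GroupOffloadingType("block_level") is GroupOffloadingType.BLOCK_LEVEL
    assert GroupOffloadingType(GroupOffloadingType.LEAF_LEVEL) is GroupOffloadingType.LEAF_LEVEL
    # The str mixin also makes equality against raw strings hold:
    assert GroupOffloadingType.BLOCK_LEVEL == "block_level"
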
@@ -55,10 +72,12 @@ class ModuleGroup:
         parameters: Optional[List[torch.nn.Parameter]] = None,
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
-        stream: Optional[torch.cuda.Stream] = None,
+        stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
         record_stream: Optional[bool] = False,
-        low_cpu_mem_usage=False,
+        low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
+        offload_to_disk_path: Optional[str] = None,
+        group_id: Optional[int] = None,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
@@ -72,10 +91,35 @@ class ModuleGroup:
         self.record_stream = record_stream
         self.onload_self = onload_self
         self.low_cpu_mem_usage = low_cpu_mem_usage
-        self.cpu_param_dict = self._init_cpu_param_dict()
 
-        if self.stream is None and self.record_stream:
-            raise ValueError("`record_stream` cannot be True when `stream` is None.")
+        self.offload_to_disk_path = offload_to_disk_path
+        self._is_offloaded_to_disk = False
+
+        if self.offload_to_disk_path is not None:
+            # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
+            self.group_id = group_id if group_id is not None else str(id(self))
+            short_hash = _compute_group_hash(self.group_id)
+            self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors")
+
+            all_tensors = []
+            for module in self.modules:
+                all_tensors.extend(list(module.parameters()))
+                all_tensors.extend(list(module.buffers()))
+            all_tensors.extend(self.parameters)
+            all_tensors.extend(self.buffers)
+            all_tensors = list(dict.fromkeys(all_tensors))  # Remove duplicates
+
+            self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
+            self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
+            self.cpu_param_dict = {}
+        else:
+            self.cpu_param_dict = self._init_cpu_param_dict()
+
+        self._torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
 
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
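
The constructor above derives the on-disk file name from `_compute_group_hash(self.group_id)`, a helper whose body falls outside the captured hunks. A plausible reconstruction consistent with the new `hashlib` import; the digest algorithm and truncation length are assumptions, not confirmed by this diff:

    import hashlib

    def _compute_group_hash(group_id: str) -> str:
        # Hypothetical sketch: map an arbitrary (possibly empty) group id to a
        # short, filesystem-safe token for "group_{hash}.safetensors" file names.
        return hashlib.sha256(group_id.encode("utf-8")).hexdigest()[:16]

    print(_compute_group_hash("transformer_blocks.0"))
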
@@ -100,71 +144,100 @@ class ModuleGroup:
 
     @contextmanager
     def _pinned_memory_tensors(self):
-        pinned_dict = {}
         try:
-            for param, tensor in self.cpu_param_dict.items():
-                if not tensor.is_pinned():
-                    pinned_dict[param] = tensor.pin_memory()
-                else:
-                    pinned_dict[param] = tensor
-
+            pinned_dict = {
+                param: tensor.pin_memory() if not tensor.is_pinned() else tensor
+                for param, tensor in self.cpu_param_dict.items()
+            }
             yield pinned_dict
-
         finally:
             pinned_dict = None
 
-    def onload_(self):
-        r"""Onloads the group of modules to the onload_device."""
-        context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
-        current_stream = torch.cuda.current_stream() if self.record_stream else None
+    def _transfer_tensor_to_device(self, tensor, source_tensor):
+        tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+        if self.record_stream:
+            tensor.data.record_stream(self._torch_accelerator_module.current_stream())
+
+    def _process_tensors_from_modules(self, pinned_memory=None):
+        for group_module in self.modules:
+            for param in group_module.parameters():
+                source = pinned_memory[param] if pinned_memory else param.data
+                self._transfer_tensor_to_device(param, source)
+            for buffer in group_module.buffers():
+                source = pinned_memory[buffer] if pinned_memory else buffer.data
+                self._transfer_tensor_to_device(buffer, source)
+
+        for param in self.parameters:
+            source = pinned_memory[param] if pinned_memory else param.data
+            self._transfer_tensor_to_device(param, source)
+
+        for buffer in self.buffers:
+            source = pinned_memory[buffer] if pinned_memory else buffer.data
+            self._transfer_tensor_to_device(buffer, source)
 
+    def _onload_from_disk(self):
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()
 
+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
+
         with context:
-            if self.stream is not None:
-                with self._pinned_memory_tensors() as pinned_memory:
-                    for group_module in self.modules:
-                        for param in group_module.parameters():
-                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
-                            if self.record_stream:
-                                param.data.record_stream(current_stream)
-                        for buffer in group_module.buffers():
-                            buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
-                            if self.record_stream:
-                                buffer.data.record_stream(current_stream)
-
-                    for param in self.parameters:
-                        param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            param.data.record_stream(current_stream)
-
-                    for buffer in self.buffers:
-                        buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            buffer.data.record_stream(current_stream)
+            # Load to CPU (if using streams) or directly to target device, pin, and async copy to device
+            device = str(self.onload_device) if self.stream is None else "cpu"
+            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
 
+            if self.stream is not None:
+                for key, tensor_obj in self.key_to_tensor.items():
+                    pinned_tensor = loaded_tensors[key].pin_memory()
+                    tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+                    if self.record_stream:
+                        tensor_obj.data.record_stream(current_stream)
             else:
-                for group_module in self.modules:
-                    for param in group_module.parameters():
-                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-                    for buffer in group_module.buffers():
-                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-
-                for param in self.parameters:
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                onload_device = (
+                    self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+                )
+                loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+                for key, tensor_obj in self.key_to_tensor.items():
+                    tensor_obj.data = loaded_tensors[key]
 
-                for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-                    if self.record_stream:
-                        buffer.data.record_stream(current_stream)
+    def _onload_from_memory(self):
+        if self.stream is not None:
+            # Wait for previous Host->Device transfer to complete
+            self.stream.synchronize()
 
-    def offload_(self):
-        r"""Offloads the group of modules to the offload_device."""
+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        with context:
+            if self.stream is not None:
+                with self._pinned_memory_tensors() as pinned_memory:
+                    self._process_tensors_from_modules(pinned_memory)
+            else:
+                self._process_tensors_from_modules(None)
+
+    def _offload_to_disk(self):
+        # TODO: we can potentially optimize this code path by checking if the _all_ the desired
+        # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
+        # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
+        # we perform a write.
+        # Check if the file has been saved in this session or if it already exists on disk.
+        if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
+            os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
+            tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
+            safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
+
+        # The group is now considered offloaded to disk for the rest of the session.
+        self._is_offloaded_to_disk = True
+
+        # We do this to free up the RAM which is still holding the up tensor data.
+        for tensor_obj in self.tensor_to_key.keys():
+            tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
+
+    def _offload_to_memory(self):
         if self.stream is not None:
             if not self.record_stream:
-                torch.cuda.current_stream().synchronize()
+                self._torch_accelerator_module.current_stream().synchronize()
+
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
@@ -172,14 +245,29 @@ class ModuleGroup:
                 param.data = self.cpu_param_dict[param]
             for buffer in self.buffers:
                 buffer.data = self.cpu_param_dict[buffer]
-
         else:
             for group_module in self.modules:
-                group_module.to(self.offload_device, non_blocking=self.non_blocking)
+                group_module.to(self.offload_device, non_blocking=False)
             for param in self.parameters:
-                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+                param.data = param.data.to(self.offload_device, non_blocking=False)
             for buffer in self.buffers:
-                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
+
+    @torch.compiler.disable()
+    def onload_(self):
+        r"""Onloads the group of parameters to the onload_device."""
+        if self.offload_to_disk_path is not None:
+            self._onload_from_disk()
+        else:
+            self._onload_from_memory()
+
+    @torch.compiler.disable()
+    def offload_(self):
+        r"""Offloads the group of parameters to the offload_device."""
+        if self.offload_to_disk_path:
+            self._offload_to_disk()
+        else:
+            self._offload_to_memory()
 
 
 class GroupOffloadingHook(ModelHook):
@@ -192,13 +280,10 @@ class GroupOffloadingHook(ModelHook):
 
     _is_stateful = False
 
-    def __init__(
-        self,
-        group: ModuleGroup,
-        next_group: Optional[ModuleGroup] = None,
-    ) -> None:
+    def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
         self.group = group
-        self.next_group = next_group
+        self.next_group: Optional[ModuleGroup] = None
+        self.config = config
 
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         if self.group.offload_leader == module:
@@ -217,9 +302,23 @@ class GroupOffloadingHook(ModelHook):
         if self.group.onload_leader == module:
             if self.group.onload_self:
                 self.group.onload_()
-            if self.next_group is not None and not self.next_group.onload_self:
+
+            should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
+            if should_onload_next_group:
                 self.next_group.onload_()
 
+            should_synchronize = (
+                not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
+            )
+            if should_synchronize:
+                # If this group didn't onload itself, it means it was asynchronously onloaded by the
+                # previous group. We need to synchronize the side stream to ensure parameters
+                # are completely loaded to proceed with forward pass. Without this, uninitialized
+                # weights will be used in the computation, leading to incorrect results
+                # Also, we should only do this synchronization if we don't already do it from the sync call in
+                # self.next_group.onload_, hence the `not should_onload_next_group` check.
+                self.group.stream.synchronize()
+
         args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
         kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
         return args, kwargs
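
The synchronization added above follows the standard pinned-memory plus side-stream overlap pattern: a producer stream copies weights host-to-device while compute proceeds, and a consumer must synchronize (or rely on `record_stream`) before touching them. A CUDA-only sketch of the pattern in isolation; the real hooks abstract this behind `_torch_accelerator_module`, and the tensor size is arbitrary:

    import torch

    assert torch.cuda.is_available()
    side_stream = torch.cuda.Stream()
    consumer_stream = torch.cuda.current_stream()  # the stream that will use the weights

    cpu_tensor = torch.randn(1024, 1024).pin_memory()  # pinning enables true async H2D copies
    with torch.cuda.stream(side_stream):
        gpu_tensor = cpu_tensor.to("cuda", non_blocking=True)
        # record_stream tells the caching allocator the consumer stream still uses this memory
        gpu_tensor.record_stream(consumer_stream)

    # Without this sync, the consumer could read a partially copied tensor.
    side_stream.synchronize()
    print(gpu_tensor.sum())
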
@@ -232,7 +331,7 @@
 
 class LazyPrefetchGroupOffloadingHook(ModelHook):
     r"""
-    A hook, used in conjuction with GroupOffloadingHook, that applies lazy prefetching to groups of torch.nn.Module.
+    A hook, used in conjunction with GroupOffloadingHook, that applies lazy prefetching to groups of torch.nn.Module.
     This hook is used to determine the order in which the layers are executed during the forward pass. Once the layer
     invocation order is known, assignments of the next_group attribute for prefetching can be made, which allows
     prefetching groups in the correct order.
@@ -247,7 +346,8 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
     def initialize_hook(self, module):
         def make_execution_order_update_callback(current_name, current_submodule):
             def callback():
-                logger.debug(f"Adding {current_name} to the execution order")
+                if not torch.compiler.is_compiling():
+                    logger.debug(f"Adding {current_name} to the execution order")
                 self.execution_order.append((current_name, current_submodule))
 
             return callback
@@ -284,12 +384,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
         # if the missing layers end up being executed in the future.
         if execution_order_module_names != self._layer_execution_tracker_module_names:
             unexecuted_layers = list(self._layer_execution_tracker_module_names - execution_order_module_names)
-            logger.warning(
-                "It seems like some layers were not executed during the forward pass. This may lead to problems when "
-                "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
-                "make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
-                f"{unexecuted_layers=}"
-            )
+            if not torch.compiler.is_compiling():
+                logger.warning(
+                    "It seems like some layers were not executed during the forward pass. This may lead to problems when "
+                    "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
+                    "make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
+                    f"{unexecuted_layers=}"
+                )
 
         # Remove the layer execution tracker hooks from the submodules
         base_module_registry = module._diffusers_hook
@@ -317,7 +418,8 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
         for i in range(num_executed - 1):
             name1, _ = self.execution_order[i]
             name2, _ = self.execution_order[i + 1]
-            logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
+            if not torch.compiler.is_compiling():
+                logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
             group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group
             group_offloading_hooks[i].next_group.onload_self = False
 
@@ -342,14 +444,15 @@ class LayerExecutionTrackerHook(ModelHook):
 
 def apply_group_offloading(
     module: torch.nn.Module,
-    onload_device: torch.device,
-    offload_device: torch.device = torch.device("cpu"),
-    offload_type: str = "block_level",
+    onload_device: Union[str, torch.device],
+    offload_device: Union[str, torch.device] = torch.device("cpu"),
+    offload_type: Union[str, GroupOffloadingType] = "block_level",
     num_blocks_per_group: Optional[int] = None,
    non_blocking: bool = False,
     use_stream: bool = False,
     record_stream: bool = False,
     low_cpu_mem_usage: bool = False,
+    offload_to_disk_path: Optional[str] = None,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -385,9 +488,12 @@ def apply_group_offloading(
             The device to which the group of modules are onloaded.
         offload_device (`torch.device`, defaults to `torch.device("cpu")`):
             The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU.
-        offload_type (`str`, defaults to "block_level"):
+        offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
             The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
             "block_level".
+        offload_to_disk_path (`str`, *optional*, defaults to `None`):
+            The path to the directory to which parameters will be offloaded. Setting this option can be useful in
+            environments with limited RAM, where a reasonable speed-memory trade-off is desired.
         num_blocks_per_group (`int`, *optional*):
             The number of blocks per group when using offload_type="block_level". This is required when using
             offload_type="block_level".
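Taken together, the new signature accepts plain device strings, a `GroupOffloadingType` enum or its string value, and an optional disk-offload directory. A minimal usage sketch under those assumptions; the checkpoint id and target directory are illustrative:

```python
import torch
from diffusers import FluxTransformer2DModel
from diffusers.hooks import apply_group_offloading

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
apply_group_offloading(
    transformer,
    onload_device="cuda",              # strings are now coerced to torch.device
    offload_device="cpu",
    offload_type="block_level",        # or GroupOffloadingType.BLOCK_LEVEL
    num_blocks_per_group=1,
    offload_to_disk_path="./offload",  # new: spill offloaded parameters to disk
)
```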
@@ -425,80 +531,61 @@ def apply_group_offloading(
     ```
     """
 
+    onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
+    offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
+    offload_type = GroupOffloadingType(offload_type)
+
     stream = None
     if use_stream:
         if torch.cuda.is_available():
             stream = torch.cuda.Stream()
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            stream = torch.Stream()
         else:
-            raise ValueError("Using streams for data transfer requires a CUDA device.")
+            raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")
+
+    if not use_stream and record_stream:
+        raise ValueError("`record_stream` cannot be True when `use_stream=False`.")
+    if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None:
+        raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'`.")
 
     _raise_error_if_accelerate_model_or_sequential_hook_present(module)
 
-    if offload_type == "block_level":
-        if num_blocks_per_group is None:
-            raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
-
-        _apply_group_offloading_block_level(
-            module=module,
-            num_blocks_per_group=num_blocks_per_group,
-            offload_device=offload_device,
-            onload_device=onload_device,
-            non_blocking=non_blocking,
-            stream=stream,
-            record_stream=record_stream,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-        )
-    elif offload_type == "leaf_level":
-        _apply_group_offloading_leaf_level(
-            module=module,
-            offload_device=offload_device,
-            onload_device=onload_device,
-            non_blocking=non_blocking,
-            stream=stream,
-            record_stream=record_stream,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-        )
+    config = GroupOffloadingConfig(
+        onload_device=onload_device,
+        offload_device=offload_device,
+        offload_type=offload_type,
+        num_blocks_per_group=num_blocks_per_group,
+        non_blocking=non_blocking,
+        stream=stream,
+        record_stream=record_stream,
+        low_cpu_mem_usage=low_cpu_mem_usage,
+        offload_to_disk_path=offload_to_disk_path,
+    )
+    _apply_group_offloading(module, config)
+
+
+def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
+    if config.offload_type == GroupOffloadingType.BLOCK_LEVEL:
+        _apply_group_offloading_block_level(module, config)
+    elif config.offload_type == GroupOffloadingType.LEAF_LEVEL:
+        _apply_group_offloading_leaf_level(module, config)
     else:
-        raise ValueError(f"Unsupported offload_type: {offload_type}")
+        assert False
 
 
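All per-call options now travel through a single `GroupOffloadingConfig`, so the `block_level` and `leaf_level` paths no longer duplicate argument plumbing. The config class itself is defined earlier in the file and is not part of this hunk; as a rough sketch only, assuming a plain dataclass mirroring the keyword arguments above:

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional

import torch


class GroupOffloadingType(str, Enum):
    # mirrors the string values accepted by apply_group_offloading
    BLOCK_LEVEL = "block_level"
    LEAF_LEVEL = "leaf_level"


@dataclass
class GroupOffloadingConfig:
    onload_device: torch.device
    offload_device: torch.device
    offload_type: GroupOffloadingType
    num_blocks_per_group: Optional[int] = None
    non_blocking: bool = False
    stream: Optional[torch.Stream] = None  # a torch.cuda.Stream may be stored here too
    record_stream: bool = False
    low_cpu_mem_usage: bool = False
    offload_to_disk_path: Optional[str] = None
```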
-def _apply_group_offloading_block_level(
-    module: torch.nn.Module,
-    num_blocks_per_group: int,
-    offload_device: torch.device,
-    onload_device: torch.device,
-    non_blocking: bool,
-    stream: Optional[torch.cuda.Stream] = None,
-    record_stream: Optional[bool] = False,
-    low_cpu_mem_usage: bool = False,
-) -> None:
+def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
     the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
-
-    Args:
-        module (`torch.nn.Module`):
-            The module to which group offloading is applied.
-        offload_device (`torch.device`):
-            The device to which the group of modules are offloaded. This should typically be the CPU.
-        onload_device (`torch.device`):
-            The device to which the group of modules are onloaded.
-        non_blocking (`bool`):
-            If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
-            and data transfer.
-        stream (`torch.cuda.Stream`, *optional*):
-            If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
-            for overlapping computation and data transfer.
-        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
-            as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
-            [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
-            details.
-        low_cpu_mem_usage (`bool`, defaults to `False`):
-            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
-            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
-            the CPU memory is a bottleneck but may counteract the benefits of using streams.
     """
 
+    if config.stream is not None and config.num_blocks_per_group != 1:
+        logger.warning(
+            f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
+        )
+        config.num_blocks_per_group = 1
+
     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
     unmatched_modules = []
@@ -509,19 +596,22 @@ def _apply_group_offloading_block_level(
             modules_with_group_offloading.add(name)
             continue
 
-        for i in range(0, len(submodule), num_blocks_per_group):
-            current_modules = submodule[i : i + num_blocks_per_group]
+        for i in range(0, len(submodule), config.num_blocks_per_group):
+            current_modules = submodule[i : i + config.num_blocks_per_group]
+            group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
             group = ModuleGroup(
                 modules=current_modules,
-                offload_device=offload_device,
-                onload_device=onload_device,
+                offload_device=config.offload_device,
+                onload_device=config.onload_device,
+                offload_to_disk_path=config.offload_to_disk_path,
                 offload_leader=current_modules[-1],
                 onload_leader=current_modules[0],
-                non_blocking=non_blocking,
-                stream=stream,
-                record_stream=record_stream,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-                onload_self=stream is None,
+                non_blocking=config.non_blocking,
+                stream=config.stream,
+                record_stream=config.record_stream,
+                low_cpu_mem_usage=config.low_cpu_mem_usage,
+                onload_self=True,
+                group_id=group_id,
             )
             matched_module_groups.append(group)
             for j in range(i, i + len(current_modules)):
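To make the slicing above concrete: a hypothetical `ModuleList` registered under the name "blocks", with five layers and `num_blocks_per_group=2`, produces three groups, the last one truncated, each with a stable `group_id` encoding the inclusive block range:

```python
import torch

blocks = torch.nn.ModuleList(torch.nn.Linear(8, 8) for _ in range(5))
name, num_blocks_per_group = "blocks", 2
for i in range(0, len(blocks), num_blocks_per_group):
    current_modules = blocks[i : i + num_blocks_per_group]
    # Same id scheme as the hunk above: "<name>_<first>_<last>"
    print(f"{name}_{i}_{i + len(current_modules) - 1}", len(current_modules))
# Output:
# blocks_0_1 2
# blocks_2_3 2
# blocks_4_4 1
```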
@@ -529,12 +619,8 @@ def _apply_group_offloading_block_level(
 
     # Apply group offloading hooks to the module groups
     for i, group in enumerate(matched_module_groups):
-        next_group = (
-            matched_module_groups[i + 1] if i + 1 < len(matched_module_groups) and stream is not None else None
-        )
-
         for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, next_group)
+            _apply_group_offloading_hook(group_module, group, config=config)
 
     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not
@@ -549,8 +635,9 @@ def _apply_group_offloading_block_level(
     unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
     unmatched_group = ModuleGroup(
         modules=unmatched_modules,
-        offload_device=offload_device,
-        onload_device=onload_device,
+        offload_device=config.offload_device,
+        onload_device=config.onload_device,
+        offload_to_disk_path=config.offload_to_disk_path,
         offload_leader=module,
         onload_leader=module,
         parameters=parameters,
@@ -559,67 +646,41 @@ def _apply_group_offloading_block_level(
         stream=None,
         record_stream=False,
         onload_self=True,
+        group_id=f"{module.__class__.__name__}_unmatched_group",
     )
-    next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
-    _apply_group_offloading_hook(module, unmatched_group, next_group)
+    if config.stream is None:
+        _apply_group_offloading_hook(module, unmatched_group, config=config)
+    else:
+        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
 
 
-def _apply_group_offloading_leaf_level(
-    module: torch.nn.Module,
-    offload_device: torch.device,
-    onload_device: torch.device,
-    non_blocking: bool,
-    stream: Optional[torch.cuda.Stream] = None,
-    record_stream: Optional[bool] = False,
-    low_cpu_mem_usage: bool = False,
-) -> None:
+def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
     requirements. However, it can be slower compared to other offloading methods due to the excessive number of device
     synchronizations. When using devices that support streams to overlap data transfer and computation, this method can
     reduce memory usage without any performance degradation.
-
-    Args:
-        module (`torch.nn.Module`):
-            The module to which group offloading is applied.
-        offload_device (`torch.device`):
-            The device to which the group of modules are offloaded. This should typically be the CPU.
-        onload_device (`torch.device`):
-            The device to which the group of modules are onloaded.
-        non_blocking (`bool`):
-            If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
-            and data transfer.
-        stream (`torch.cuda.Stream`, *optional*):
-            If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
-            for overlapping computation and data transfer.
-        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
-            as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
-            [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
-            details.
-        low_cpu_mem_usage (`bool`, defaults to `False`):
-            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
-            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
-            the CPU memory is a bottleneck but may counteract the benefits of using streams.
     """
-
     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
     for name, submodule in module.named_modules():
-        if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
+        if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
             continue
         group = ModuleGroup(
             modules=[submodule],
-            offload_device=offload_device,
-            onload_device=onload_device,
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
+            offload_to_disk_path=config.offload_to_disk_path,
             offload_leader=submodule,
             onload_leader=submodule,
-            non_blocking=non_blocking,
-            stream=stream,
-            record_stream=record_stream,
-            low_cpu_mem_usage=low_cpu_mem_usage,
+            non_blocking=config.non_blocking,
+            stream=config.stream,
+            record_stream=config.record_stream,
+            low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
+            group_id=name,
         )
-        _apply_group_offloading_hook(submodule, group, None)
+        _apply_group_offloading_hook(submodule, group, config=config)
         modules_with_group_offloading.add(name)
 
     # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
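For contrast with the block-level example earlier, a minimal sketch of the leaf-level path; the toy `Sequential` model stands in for a real diffusers transformer, and `use_stream=True` requires a CUDA or XPU device:

```python
import torch
from diffusers.hooks import apply_group_offloading

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 16))
apply_group_offloading(
    model,
    onload_device="cuda",
    offload_type="leaf_level",  # one group per supported leaf layer; smallest footprint
    use_stream=True,            # lazy prefetching overlaps transfer with compute
)
```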
@@ -650,31 +711,33 @@ def _apply_group_offloading_leaf_level(
         parameters = parent_to_parameters.get(name, [])
         buffers = parent_to_buffers.get(name, [])
         parent_module = module_dict[name]
-        assert getattr(parent_module, "_diffusers_hook", None) is None
         group = ModuleGroup(
             modules=[],
-            offload_device=offload_device,
-            onload_device=onload_device,
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
             offload_leader=parent_module,
             onload_leader=parent_module,
+            offload_to_disk_path=config.offload_to_disk_path,
             parameters=parameters,
             buffers=buffers,
-            non_blocking=non_blocking,
-            stream=stream,
-            record_stream=record_stream,
-            low_cpu_mem_usage=low_cpu_mem_usage,
+            non_blocking=config.non_blocking,
+            stream=config.stream,
+            record_stream=config.record_stream,
+            low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
+            group_id=name,
         )
-        _apply_group_offloading_hook(parent_module, group, None)
+        _apply_group_offloading_hook(parent_module, group, config=config)
 
-    if stream is not None:
+    if config.stream is not None:
         # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
         # and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
         # execution order and apply prefetching in the correct order.
         unmatched_group = ModuleGroup(
             modules=[],
-            offload_device=offload_device,
-            onload_device=onload_device,
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
+            offload_to_disk_path=config.offload_to_disk_path,
             offload_leader=module,
             onload_leader=module,
             parameters=None,
@@ -682,37 +745,40 @@ def _apply_group_offloading_leaf_level(
             non_blocking=False,
             stream=None,
             record_stream=False,
-            low_cpu_mem_usage=low_cpu_mem_usage,
+            low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
+            group_id=_GROUP_ID_LAZY_LEAF,
         )
-        _apply_lazy_group_offloading_hook(module, unmatched_group, None)
+        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
 
 
 def _apply_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    next_group: Optional[ModuleGroup] = None,
+    *,
+    config: GroupOffloadingConfig,
 ) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(module)
 
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group)
+        hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)
 
 
 def _apply_lazy_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    next_group: Optional[ModuleGroup] = None,
+    *,
+    config: GroupOffloadingConfig,
 ) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(module)
 
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group)
+        hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)
 
     lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
@@ -779,15 +845,54 @@ def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn
             )
 
 
-def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
+def _get_top_level_group_offload_hook(module: torch.nn.Module) -> Optional[GroupOffloadingHook]:
     for submodule in module.modules():
-        if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
-            return True
-    return False
+        if hasattr(submodule, "_diffusers_hook"):
+            group_offloading_hook = submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING)
+            if group_offloading_hook is not None:
+                return group_offloading_hook
+    return None
+
+
+def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
+    top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
+    return top_level_group_offload_hook is not None
 
 
 def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
-    for submodule in module.modules():
-        if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
-            return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
+    top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
+    if top_level_group_offload_hook is not None:
+        return top_level_group_offload_hook.config.onload_device
     raise ValueError("Group offloading is not enabled for the provided module.")
+
+
+def _compute_group_hash(group_id):
+    hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
+    # first 16 characters for a reasonably short but unique name
+    return hashed_id[:16]
+
+
+def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
+    r"""
+    Removes the group offloading hook from the module and re-applies it. This is useful when the module has been
+    modified in-place and the group offloading hook's references to tensors need to be updated. The in-place
+    modification can happen in a number of ways, for example, fusing QKV projections or loading/unloading LoRAs
+    on-the-fly.
+
+    This implementation assumes that group offloading has been applied only at the top-level module, so all
+    submodules share the same onload and offload devices. If that assumption does not hold, for example when the
+    user has applied group offloading at multiple levels, this function will not work as expected.
+
+    There is some performance penalty associated with doing this when non-default streams are used, because we need
+    to retrace the execution order of the layers with `LazyPrefetchGroupOffloadingHook`.
+    """
+    top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
+
+    if top_level_group_offload_hook is None:
+        return
+
+    registry = HookRegistry.check_if_exists_or_initialize(module)
+    registry.remove_hook(_GROUP_OFFLOADING, recurse=True)
+    registry.remove_hook(_LAYER_EXECUTION_TRACKER, recurse=True)
+    registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=True)
+
+    _apply_group_offloading(module, top_level_group_offload_hook.config)
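A standalone check of the hash helper above: group ids map to short, deterministic names. That these names feed disk-offload filenames when `offload_to_disk_path` is set is an inference from the rest of this diff, not something shown in this hunk:

```python
import hashlib

def compute_group_hash(group_id: str) -> str:
    # Same scheme as _compute_group_hash: sha256, truncated to 16 hex chars.
    return hashlib.sha256(group_id.encode("utf-8")).hexdigest()[:16]

print(compute_group_hash("blocks_0_1"))  # deterministic 16-character name
print(compute_group_hash("blocks_2_3"))  # a different group id gives a different name
```

After an in-place change such as loading a LoRA, callers can invoke `_maybe_remove_and_reapply_group_offloading(module)` to strip the three hook types and rebuild offloading from the saved `config`.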