diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +145 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +3 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +2 -2
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +3 -3
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +9 -8
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +332 -227
- diffusers/hooks/hooks.py +58 -3
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +5 -10
- diffusers/hooks/pyramid_attention_broadcast.py +15 -12
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/image_processor.py +7 -2
- diffusers/loaders/__init__.py +10 -0
- diffusers/loaders/ip_adapter.py +260 -18
- diffusers/loaders/lora_base.py +261 -127
- diffusers/loaders/lora_conversion_utils.py +657 -35
- diffusers/loaders/lora_pipeline.py +2778 -1246
- diffusers/loaders/peft.py +78 -112
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_model.py +64 -15
- diffusers/loaders/single_file_utils.py +395 -7
- diffusers/loaders/textual_inversion.py +3 -2
- diffusers/loaders/transformer_flux.py +10 -11
- diffusers/loaders/transformer_sd3.py +8 -3
- diffusers/loaders/unet.py +24 -21
- diffusers/loaders/unet_loader_utils.py +6 -3
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +23 -1
- diffusers/models/activations.py +5 -5
- diffusers/models/adapter.py +2 -3
- diffusers/models/attention.py +488 -7
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_flax.py +10 -10
- diffusers/models/attention_processor.py +113 -667
- diffusers/models/auto_model.py +49 -12
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_dc.py +17 -4
- diffusers/models/autoencoders/autoencoder_kl.py +5 -5
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +13 -2
- diffusers/models/autoencoders/vq_model.py +2 -2
- diffusers/models/cache_utils.py +32 -10
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flux.py +1 -1
- diffusers/models/controlnet_sd3.py +1 -1
- diffusers/models/controlnet_sparsectrl.py +1 -1
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -3
- diffusers/models/controlnets/controlnet_flax.py +1 -1
- diffusers/models/controlnets/controlnet_flux.py +21 -20
- diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
- diffusers/models/controlnets/controlnet_sana.py +290 -0
- diffusers/models/controlnets/controlnet_sd3.py +1 -1
- diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
- diffusers/models/controlnets/controlnet_union.py +5 -5
- diffusers/models/controlnets/controlnet_xs.py +7 -7
- diffusers/models/controlnets/multicontrolnet.py +4 -5
- diffusers/models/controlnets/multicontrolnet_union.py +5 -6
- diffusers/models/downsampling.py +2 -2
- diffusers/models/embeddings.py +36 -46
- diffusers/models/embeddings_flax.py +2 -2
- diffusers/models/lora.py +3 -3
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +203 -108
- diffusers/models/normalization.py +4 -4
- diffusers/models/resnet.py +2 -2
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/transformers/__init__.py +7 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
- diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
- diffusers/models/transformers/consisid_transformer_3d.py +1 -1
- diffusers/models/transformers/dit_transformer_2d.py +2 -2
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
- diffusers/models/transformers/latte_transformer_3d.py +4 -5
- diffusers/models/transformers/lumina_nextdit2d.py +2 -2
- diffusers/models/transformers/pixart_transformer_2d.py +3 -3
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/sana_transformer.py +8 -3
- diffusers/models/transformers/stable_audio_transformer.py +5 -9
- diffusers/models/transformers/t5_film_transformer.py +3 -3
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +1 -1
- diffusers/models/transformers/transformer_chroma.py +641 -0
- diffusers/models/transformers/transformer_cogview3plus.py +5 -10
- diffusers/models/transformers/transformer_cogview4.py +353 -27
- diffusers/models/transformers/transformer_cosmos.py +586 -0
- diffusers/models/transformers/transformer_flux.py +376 -138
- diffusers/models/transformers/transformer_hidream_image.py +942 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
- diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- diffusers/models/transformers/transformer_ltx.py +105 -24
- diffusers/models/transformers/transformer_lumina2.py +1 -1
- diffusers/models/transformers/transformer_mochi.py +1 -1
- diffusers/models/transformers/transformer_omnigen.py +2 -2
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_sd3.py +7 -7
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/transformers/transformer_wan.py +316 -87
- diffusers/models/transformers/transformer_wan_vace.py +387 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +1 -1
- diffusers/models/unets/unet_2d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
- diffusers/models/unets/unet_2d_condition.py +4 -3
- diffusers/models/unets/unet_2d_condition_flax.py +2 -2
- diffusers/models/unets/unet_3d_blocks.py +1 -1
- diffusers/models/unets/unet_3d_condition.py +3 -3
- diffusers/models/unets/unet_i2vgen_xl.py +3 -3
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +2 -2
- diffusers/models/unets/unet_stable_cascade.py +1 -1
- diffusers/models/upsampling.py +2 -2
- diffusers/models/vae_flax.py +2 -2
- diffusers/models/vq_model.py +1 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +68 -6
- diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
- diffusers/pipelines/amused/pipeline_amused.py +7 -6
- diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
- diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
- diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
- diffusers/pipelines/auto_pipeline.py +23 -20
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
- diffusers/pipelines/chroma/__init__.py +49 -0
- diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- diffusers/pipelines/chroma/pipeline_output.py +21 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
- diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
- diffusers/pipelines/consisid/consisid_utils.py +2 -2
- diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
- diffusers/pipelines/cosmos/__init__.py +54 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
- diffusers/pipelines/cosmos/pipeline_output.py +40 -0
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
- diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +4 -2
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/modeling_flux.py +1 -1
- diffusers/pipelines/flux/pipeline_flux.py +37 -36
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
- diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/free_init_utils.py +2 -2
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hidream_image/__init__.py +47 -0
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
- diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
- diffusers/pipelines/hunyuan_video/__init__.py +2 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
- diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
- diffusers/pipelines/kolors/text_encoder.py +3 -3
- diffusers/pipelines/kolors/tokenizer.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
- diffusers/pipelines/latte/pipeline_latte.py +12 -12
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
- diffusers/pipelines/ltx/__init__.py +4 -0
- diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
- diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
- diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
- diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
- diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
- diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
- diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
- diffusers/pipelines/onnx_utils.py +15 -2
- diffusers/pipelines/pag/pag_utils.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
- diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
- diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
- diffusers/pipelines/pia/pipeline_pia.py +8 -6
- diffusers/pipelines/pipeline_flax_utils.py +5 -6
- diffusers/pipelines/pipeline_loading_utils.py +113 -15
- diffusers/pipelines/pipeline_utils.py +127 -48
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/__init__.py +4 -0
- diffusers/pipelines/sana/pipeline_sana.py +23 -21
- diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
- diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +3 -3
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
- diffusers/pipelines/stable_diffusion/__init__.py +0 -7
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
- diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
- diffusers/pipelines/unclip/text_proj.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
- diffusers/pipelines/visualcloze/__init__.py +52 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
- diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
- diffusers/pipelines/wan/__init__.py +2 -0
- diffusers/pipelines/wan/pipeline_wan.py +91 -30
- diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
- diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
- diffusers/quantizers/__init__.py +3 -1
- diffusers/quantizers/base.py +17 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
- diffusers/quantizers/bitsandbytes/utils.py +10 -7
- diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
- diffusers/quantizers/gguf/utils.py +108 -16
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/quantization_config.py +18 -16
- diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
- diffusers/schedulers/__init__.py +3 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -1
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
- diffusers/schedulers/scheduling_ddim.py +8 -8
- diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_ddim_flax.py +6 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
- diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
- diffusers/schedulers/scheduling_ddpm.py +9 -9
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
- diffusers/schedulers/scheduling_deis_multistep.py +16 -9
- diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
- diffusers/schedulers/scheduling_edm_euler.py +20 -11
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
- diffusers/schedulers/scheduling_heun_discrete.py +2 -2
- diffusers/schedulers/scheduling_ipndm.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
- diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
- diffusers/schedulers/scheduling_lcm.py +3 -3
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +4 -4
- diffusers/schedulers/scheduling_pndm_flax.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +9 -9
- diffusers/schedulers/scheduling_sasolver.py +15 -15
- diffusers/schedulers/scheduling_scm.py +1 -2
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
- diffusers/schedulers/scheduling_tcd.py +3 -3
- diffusers/schedulers/scheduling_unclip.py +5 -5
- diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
- diffusers/schedulers/scheduling_utils.py +3 -3
- diffusers/schedulers/scheduling_utils_flax.py +2 -2
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +91 -5
- diffusers/utils/__init__.py +15 -0
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +4 -0
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +432 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
- diffusers/utils/dynamic_modules_utils.py +85 -8
- diffusers/utils/export_utils.py +1 -1
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +151 -18
- diffusers/utils/logging.py +1 -1
- diffusers/utils/outputs.py +2 -1
- diffusers/utils/peft_utils.py +96 -10
- diffusers/utils/state_dict_utils.py +20 -3
- diffusers/utils/testing_utils.py +195 -17
- diffusers/utils/torch_utils.py +43 -5
- diffusers/video_processor.py +2 -2
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
- diffusers-0.35.0.dist-info/RECORD +703 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
- diffusers-0.33.1.dist-info/RECORD +0 -608
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
````diff
--- a/diffusers/hooks/group_offloading.py
+++ b/diffusers/hooks/group_offloading.py
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,12 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import hashlib
+import os
 from contextlib import contextmanager, nullcontext
-from
+from dataclasses import dataclass
+from enum import Enum
+from typing import Dict, List, Optional, Set, Tuple, Union
 
+import safetensors.torch
 import torch
 
 from ..utils import get_logger, is_accelerate_available
+from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
 from .hooks import HookRegistry, ModelHook
 
 
@@ -33,17 +39,28 @@ logger = get_logger(__name__) # pylint: disable=invalid-name
 _GROUP_OFFLOADING = "group_offloading"
 _LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
 _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
-
-_SUPPORTED_PYTORCH_LAYERS = (
-    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
-    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
-    torch.nn.Linear,
-    # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
-    # because of double invocation of the same norm layer in CogVideoXLayerNorm
-)
+_GROUP_ID_LAZY_LEAF = "lazy_leafs"
 # fmt: on
 
 
+class GroupOffloadingType(str, Enum):
+    BLOCK_LEVEL = "block_level"
+    LEAF_LEVEL = "leaf_level"
+
+
+@dataclass
+class GroupOffloadingConfig:
+    onload_device: torch.device
+    offload_device: torch.device
+    offload_type: GroupOffloadingType
+    non_blocking: bool
+    record_stream: bool
+    low_cpu_mem_usage: bool
+    num_blocks_per_group: Optional[int] = None
+    offload_to_disk_path: Optional[str] = None
+    stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
+
+
 class ModuleGroup:
     def __init__(
         self,
@@ -55,10 +72,12 @@ class ModuleGroup:
         parameters: Optional[List[torch.nn.Parameter]] = None,
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
-        stream:
+        stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
         record_stream: Optional[bool] = False,
-        low_cpu_mem_usage=False,
+        low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
+        offload_to_disk_path: Optional[str] = None,
+        group_id: Optional[int] = None,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
@@ -72,10 +91,35 @@ class ModuleGroup:
         self.record_stream = record_stream
         self.onload_self = onload_self
         self.low_cpu_mem_usage = low_cpu_mem_usage
-        self.cpu_param_dict = self._init_cpu_param_dict()
 
-
-
+        self.offload_to_disk_path = offload_to_disk_path
+        self._is_offloaded_to_disk = False
+
+        if self.offload_to_disk_path is not None:
+            # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
+            self.group_id = group_id if group_id is not None else str(id(self))
+            short_hash = _compute_group_hash(self.group_id)
+            self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors")
+
+            all_tensors = []
+            for module in self.modules:
+                all_tensors.extend(list(module.parameters()))
+                all_tensors.extend(list(module.buffers()))
+            all_tensors.extend(self.parameters)
+            all_tensors.extend(self.buffers)
+            all_tensors = list(dict.fromkeys(all_tensors))  # Remove duplicates
+
+            self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
+            self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
+            self.cpu_param_dict = {}
+        else:
+            self.cpu_param_dict = self._init_cpu_param_dict()
+
+        self._torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
 
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
@@ -100,71 +144,100 @@ class ModuleGroup:
 
     @contextmanager
     def _pinned_memory_tensors(self):
-        pinned_dict = {}
         try:
-
-                if not tensor.is_pinned()
-
-
-                    pinned_dict[param] = tensor
-
+            pinned_dict = {
+                param: tensor.pin_memory() if not tensor.is_pinned() else tensor
+                for param, tensor in self.cpu_param_dict.items()
+            }
             yield pinned_dict
-
         finally:
             pinned_dict = None
 
-    def
-
-
-
+    def _transfer_tensor_to_device(self, tensor, source_tensor):
+        tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+        if self.record_stream:
+            tensor.data.record_stream(self._torch_accelerator_module.current_stream())
+
+    def _process_tensors_from_modules(self, pinned_memory=None):
+        for group_module in self.modules:
+            for param in group_module.parameters():
+                source = pinned_memory[param] if pinned_memory else param.data
+                self._transfer_tensor_to_device(param, source)
+            for buffer in group_module.buffers():
+                source = pinned_memory[buffer] if pinned_memory else buffer.data
+                self._transfer_tensor_to_device(buffer, source)
+
+        for param in self.parameters:
+            source = pinned_memory[param] if pinned_memory else param.data
+            self._transfer_tensor_to_device(param, source)
+
+        for buffer in self.buffers:
+            source = pinned_memory[buffer] if pinned_memory else buffer.data
+            self._transfer_tensor_to_device(buffer, source)
 
+    def _onload_from_disk(self):
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()
 
+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
+
         with context:
-            if
-
-
-                    for param in group_module.parameters():
-                        param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            param.data.record_stream(current_stream)
-                    for buffer in group_module.buffers():
-                        buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            buffer.data.record_stream(current_stream)
-
-                for param in self.parameters:
-                    param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
-                    if self.record_stream:
-                        param.data.record_stream(current_stream)
-
-                for buffer in self.buffers:
-                    buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
-                    if self.record_stream:
-                        buffer.data.record_stream(current_stream)
+            # Load to CPU (if using streams) or directly to target device, pin, and async copy to device
+            device = str(self.onload_device) if self.stream is None else "cpu"
+            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
 
+            if self.stream is not None:
+                for key, tensor_obj in self.key_to_tensor.items():
+                    pinned_tensor = loaded_tensors[key].pin_memory()
+                    tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+                    if self.record_stream:
+                        tensor_obj.data.record_stream(current_stream)
             else:
-
-
-
-
-
-
-                for param in self.parameters:
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                onload_device = (
+                    self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+                )
+                loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+                for key, tensor_obj in self.key_to_tensor.items():
+                    tensor_obj.data = loaded_tensors[key]
 
-
-
-
-
+    def _onload_from_memory(self):
+        if self.stream is not None:
+            # Wait for previous Host->Device transfer to complete
+            self.stream.synchronize()
 
-
-
+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        with context:
+            if self.stream is not None:
+                with self._pinned_memory_tensors() as pinned_memory:
+                    self._process_tensors_from_modules(pinned_memory)
+            else:
+                self._process_tensors_from_modules(None)
+
+    def _offload_to_disk(self):
+        # TODO: we can potentially optimize this code path by checking if the _all_ the desired
+        # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
+        # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
+        # we perform a write.
+        # Check if the file has been saved in this session or if it already exists on disk.
+        if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
+            os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
+            tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
+            safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
+
+        # The group is now considered offloaded to disk for the rest of the session.
+        self._is_offloaded_to_disk = True
+
+        # We do this to free up the RAM which is still holding the up tensor data.
+        for tensor_obj in self.tensor_to_key.keys():
+            tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
+
+    def _offload_to_memory(self):
         if self.stream is not None:
             if not self.record_stream:
-
+                self._torch_accelerator_module.current_stream().synchronize()
+
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
@@ -172,14 +245,29 @@ class ModuleGroup:
                 param.data = self.cpu_param_dict[param]
             for buffer in self.buffers:
                 buffer.data = self.cpu_param_dict[buffer]
-
         else:
             for group_module in self.modules:
-                group_module.to(self.offload_device, non_blocking=
+                group_module.to(self.offload_device, non_blocking=False)
             for param in self.parameters:
-                param.data = param.data.to(self.offload_device, non_blocking=
+                param.data = param.data.to(self.offload_device, non_blocking=False)
            for buffer in self.buffers:
-                buffer.data = buffer.data.to(self.offload_device, non_blocking=
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
+
+    @torch.compiler.disable()
+    def onload_(self):
+        r"""Onloads the group of parameters to the onload_device."""
+        if self.offload_to_disk_path is not None:
+            self._onload_from_disk()
+        else:
+            self._onload_from_memory()
+
+    @torch.compiler.disable()
+    def offload_(self):
+        r"""Offloads the group of parameters to the offload_device."""
+        if self.offload_to_disk_path:
+            self._offload_to_disk()
+        else:
+            self._offload_to_memory()
 
 
 class GroupOffloadingHook(ModelHook):
@@ -192,13 +280,10 @@ class GroupOffloadingHook(ModelHook):
 
     _is_stateful = False
 
-    def __init__(
-        self,
-        group: ModuleGroup,
-        next_group: Optional[ModuleGroup] = None,
-    ) -> None:
+    def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
         self.group = group
-        self.next_group =
+        self.next_group: Optional[ModuleGroup] = None
+        self.config = config
 
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         if self.group.offload_leader == module:
@@ -217,9 +302,23 @@ class GroupOffloadingHook(ModelHook):
         if self.group.onload_leader == module:
             if self.group.onload_self:
                 self.group.onload_()
-
+
+            should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
+            if should_onload_next_group:
                 self.next_group.onload_()
 
+            should_synchronize = (
+                not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
+            )
+            if should_synchronize:
+                # If this group didn't onload itself, it means it was asynchronously onloaded by the
+                # previous group. We need to synchronize the side stream to ensure parameters
+                # are completely loaded to proceed with forward pass. Without this, uninitialized
+                # weights will be used in the computation, leading to incorrect results
+                # Also, we should only do this synchronization if we don't already do it from the sync call in
+                # self.next_group.onload_, hence the `not should_onload_next_group` check.
+                self.group.stream.synchronize()
+
         args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
         kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
         return args, kwargs
````
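The synchronization added to `GroupOffloadingHook.pre_forward` above implements a double-buffering pattern: while group N computes, group N+1's weights are copied host-to-device on a side stream, and that stream is synchronized before the prefetched weights are consumed. Below is a rough, self-contained illustration of the same pattern in plain PyTorch; it is not diffusers code, it assumes a CUDA device, and the shapes and group count are arbitrary.

```python
import torch

# Toy stand-ins for the per-group weights that ModuleGroup manages.
side_stream = torch.cuda.Stream()
weights_cpu = [torch.randn(1024, 1024, pin_memory=True) for _ in range(4)]
weights_gpu = [None] * len(weights_cpu)

def prefetch(i: int) -> None:
    # Mirrors onload_(): asynchronous H2D copy on the side stream from pinned memory.
    with torch.cuda.stream(side_stream):
        weights_gpu[i] = weights_cpu[i].to("cuda", non_blocking=True)

x = torch.randn(8, 1024, device="cuda")
prefetch(0)
for i in range(len(weights_cpu)):
    # Mirrors self.group.stream.synchronize(): ensure group i is fully onloaded
    # before compute; otherwise uninitialized weights could be read.
    side_stream.synchronize()
    if i + 1 < len(weights_cpu):
        prefetch(i + 1)  # mirrors self.next_group.onload_()
    x = x @ weights_gpu[i]  # compute for group i overlaps the copy for group i + 1
```

The `record_stream` branch in the diff is an alternative way to express this cross-stream dependency to the CUDA caching allocator, at the documented cost of slightly higher memory usage.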
@@ -232,7 +331,7 @@ class GroupOffloadingHook(ModelHook):
|
|
232
331
|
|
233
332
|
class LazyPrefetchGroupOffloadingHook(ModelHook):
|
234
333
|
r"""
|
235
|
-
A hook, used in
|
334
|
+
A hook, used in conjunction with GroupOffloadingHook, that applies lazy prefetching to groups of torch.nn.Module.
|
236
335
|
This hook is used to determine the order in which the layers are executed during the forward pass. Once the layer
|
237
336
|
invocation order is known, assignments of the next_group attribute for prefetching can be made, which allows
|
238
337
|
prefetching groups in the correct order.
|
@@ -247,7 +346,8 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
|
|
247
346
|
def initialize_hook(self, module):
|
248
347
|
def make_execution_order_update_callback(current_name, current_submodule):
|
249
348
|
def callback():
|
250
|
-
|
349
|
+
if not torch.compiler.is_compiling():
|
350
|
+
logger.debug(f"Adding {current_name} to the execution order")
|
251
351
|
self.execution_order.append((current_name, current_submodule))
|
252
352
|
|
253
353
|
return callback
|
@@ -284,12 +384,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
|
|
284
384
|
# if the missing layers end up being executed in the future.
|
285
385
|
if execution_order_module_names != self._layer_execution_tracker_module_names:
|
286
386
|
unexecuted_layers = list(self._layer_execution_tracker_module_names - execution_order_module_names)
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
387
|
+
if not torch.compiler.is_compiling():
|
388
|
+
logger.warning(
|
389
|
+
"It seems like some layers were not executed during the forward pass. This may lead to problems when "
|
390
|
+
"applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
|
391
|
+
"make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
|
392
|
+
f"{unexecuted_layers=}"
|
393
|
+
)
|
293
394
|
|
294
395
|
# Remove the layer execution tracker hooks from the submodules
|
295
396
|
base_module_registry = module._diffusers_hook
|
@@ -317,7 +418,8 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
|
|
317
418
|
for i in range(num_executed - 1):
|
318
419
|
name1, _ = self.execution_order[i]
|
319
420
|
name2, _ = self.execution_order[i + 1]
|
320
|
-
|
421
|
+
if not torch.compiler.is_compiling():
|
422
|
+
logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
|
321
423
|
group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group
|
322
424
|
group_offloading_hooks[i].next_group.onload_self = False
|
323
425
|
|
@@ -342,14 +444,15 @@ class LayerExecutionTrackerHook(ModelHook):
|
|
342
444
|
|
343
445
|
def apply_group_offloading(
|
344
446
|
module: torch.nn.Module,
|
345
|
-
onload_device: torch.device,
|
346
|
-
offload_device: torch.device = torch.device("cpu"),
|
347
|
-
offload_type: str = "block_level",
|
447
|
+
onload_device: Union[str, torch.device],
|
448
|
+
offload_device: Union[str, torch.device] = torch.device("cpu"),
|
449
|
+
offload_type: Union[str, GroupOffloadingType] = "block_level",
|
348
450
|
num_blocks_per_group: Optional[int] = None,
|
349
451
|
non_blocking: bool = False,
|
350
452
|
use_stream: bool = False,
|
351
453
|
record_stream: bool = False,
|
352
454
|
low_cpu_mem_usage: bool = False,
|
455
|
+
offload_to_disk_path: Optional[str] = None,
|
353
456
|
) -> None:
|
354
457
|
r"""
|
355
458
|
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
|
@@ -385,9 +488,12 @@ def apply_group_offloading(
|
|
385
488
|
The device to which the group of modules are onloaded.
|
386
489
|
offload_device (`torch.device`, defaults to `torch.device("cpu")`):
|
387
490
|
The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU.
|
388
|
-
offload_type (`str`, defaults to "block_level"):
|
491
|
+
offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
|
389
492
|
The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
|
390
493
|
"block_level".
|
494
|
+
offload_to_disk_path (`str`, *optional*, defaults to `None`):
|
495
|
+
The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
|
496
|
+
RAM environment settings where a reasonable speed-memory trade-off is desired.
|
391
497
|
num_blocks_per_group (`int`, *optional*):
|
392
498
|
The number of blocks per group when using offload_type="block_level". This is required when using
|
393
499
|
offload_type="block_level".
|
@@ -425,80 +531,61 @@ def apply_group_offloading(
|
|
425
531
|
```
|
426
532
|
"""
|
427
533
|
|
534
|
+
onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
|
535
|
+
offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
|
536
|
+
offload_type = GroupOffloadingType(offload_type)
|
537
|
+
|
428
538
|
stream = None
|
429
539
|
if use_stream:
|
430
540
|
if torch.cuda.is_available():
|
431
541
|
stream = torch.cuda.Stream()
|
542
|
+
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
543
|
+
stream = torch.Stream()
|
432
544
|
else:
|
433
|
-
raise ValueError("Using streams for data transfer requires a CUDA device.")
|
545
|
+
raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")
|
546
|
+
|
547
|
+
if not use_stream and record_stream:
|
548
|
+
raise ValueError("`record_stream` cannot be True when `use_stream=False`.")
|
549
|
+
if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None:
|
550
|
+
raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.")
|
434
551
|
|
435
552
|
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
|
436
553
|
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
non_blocking=non_blocking,
|
457
|
-
stream=stream,
|
458
|
-
record_stream=record_stream,
|
459
|
-
low_cpu_mem_usage=low_cpu_mem_usage,
|
460
|
-
)
|
554
|
+
config = GroupOffloadingConfig(
|
555
|
+
onload_device=onload_device,
|
556
|
+
offload_device=offload_device,
|
557
|
+
offload_type=offload_type,
|
558
|
+
num_blocks_per_group=num_blocks_per_group,
|
559
|
+
non_blocking=non_blocking,
|
560
|
+
stream=stream,
|
561
|
+
record_stream=record_stream,
|
562
|
+
low_cpu_mem_usage=low_cpu_mem_usage,
|
563
|
+
offload_to_disk_path=offload_to_disk_path,
|
564
|
+
)
|
565
|
+
_apply_group_offloading(module, config)
|
566
|
+
|
567
|
+
|
568
|
+
def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
|
569
|
+
if config.offload_type == GroupOffloadingType.BLOCK_LEVEL:
|
570
|
+
_apply_group_offloading_block_level(module, config)
|
571
|
+
elif config.offload_type == GroupOffloadingType.LEAF_LEVEL:
|
572
|
+
_apply_group_offloading_leaf_level(module, config)
|
461
573
|
else:
|
462
|
-
|
574
|
+
assert False
|
463
575
|
|
464
576
|
|
-def _apply_group_offloading_block_level(
-    module: torch.nn.Module,
-    num_blocks_per_group: int,
-    offload_device: torch.device,
-    onload_device: torch.device,
-    non_blocking: bool,
-    stream: Optional[torch.cuda.Stream] = None,
-    record_stream: Optional[bool] = False,
-    low_cpu_mem_usage: bool = False,
-) -> None:
+def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
     the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
-
-    Args:
-        module (`torch.nn.Module`):
-            The module to which group offloading is applied.
-        offload_device (`torch.device`):
-            The device to which the group of modules are offloaded. This should typically be the CPU.
-        onload_device (`torch.device`):
-            The device to which the group of modules are onloaded.
-        non_blocking (`bool`):
-            If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
-            and data transfer.
-        stream (`torch.cuda.Stream`, *optional*):
-            If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
-            for overlapping computation and data transfer.
-        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
-            as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
-            [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
-            details.
-        low_cpu_mem_usage (`bool`, defaults to `False`):
-            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
-            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
-            the CPU memory is a bottleneck but may counteract the benefits of using streams.
     """
 
+    if config.stream is not None and config.num_blocks_per_group != 1:
+        logger.warning(
+            f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
+        )
+        config.num_blocks_per_group = 1
+
     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
     unmatched_modules = []
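Block-level offloading walks `ModuleList`/`Sequential` children in strides of `num_blocks_per_group` (forced to 1 when streams are active, per the warning above). A toy illustration of the partitioning, with hypothetical sizes:

```python
blocks = list(range(5))  # stand-in for a ModuleList with 5 transformer blocks
num_blocks_per_group = 2
groups = [blocks[i : i + num_blocks_per_group] for i in range(0, len(blocks), num_blocks_per_group)]
assert groups == [[0, 1], [2, 3], [4]]  # the trailing group may be smaller
```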
@@ -509,19 +596,22 @@ def _apply_group_offloading_block_level(
             modules_with_group_offloading.add(name)
             continue
 
-        for i in range(0, len(submodule), num_blocks_per_group):
-            current_modules = submodule[i : i + num_blocks_per_group]
+        for i in range(0, len(submodule), config.num_blocks_per_group):
+            current_modules = submodule[i : i + config.num_blocks_per_group]
+            group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
             group = ModuleGroup(
                 modules=current_modules,
-                offload_device=offload_device,
-                onload_device=onload_device,
+                offload_device=config.offload_device,
+                onload_device=config.onload_device,
+                offload_to_disk_path=config.offload_to_disk_path,
                 offload_leader=current_modules[-1],
                 onload_leader=current_modules[0],
-                non_blocking=non_blocking,
-                stream=stream,
-                record_stream=record_stream,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-                onload_self=
+                non_blocking=config.non_blocking,
+                stream=config.stream,
+                record_stream=config.record_stream,
+                low_cpu_mem_usage=config.low_cpu_mem_usage,
+                onload_self=True,
+                group_id=group_id,
             )
             matched_module_groups.append(group)
             for j in range(i, i + len(current_modules)):
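Each group now also carries a deterministic `group_id` of the form `{name}_{first}_{last}`, which later feeds the on-disk naming (see `_compute_group_hash` further down). Applied to the toy 5-block list from the previous sketch:

```python
name, step, blocks = "transformer_blocks", 2, list(range(5))
group_ids = [
    f"{name}_{i}_{i + len(blocks[i : i + step]) - 1}"
    for i in range(0, len(blocks), step)
]
assert group_ids == ["transformer_blocks_0_1", "transformer_blocks_2_3", "transformer_blocks_4_4"]
```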
@@ -529,12 +619,8 @@ def _apply_group_offloading_block_level(
 
     # Apply group offloading hooks to the module groups
     for i, group in enumerate(matched_module_groups):
-        next_group = (
-            matched_module_groups[i + 1] if i + 1 < len(matched_module_groups) and stream is not None else None
-        )
-
         for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group,
+            _apply_group_offloading_hook(group_module, group, config=config)
 
     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not
@@ -549,8 +635,9 @@ def _apply_group_offloading_block_level(
     unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
     unmatched_group = ModuleGroup(
         modules=unmatched_modules,
-        offload_device=offload_device,
-        onload_device=onload_device,
+        offload_device=config.offload_device,
+        onload_device=config.onload_device,
+        offload_to_disk_path=config.offload_to_disk_path,
         offload_leader=module,
         onload_leader=module,
         parameters=parameters,
@@ -559,67 +646,41 @@ def _apply_group_offloading_block_level(
         stream=None,
         record_stream=False,
         onload_self=True,
+        group_id=f"{module.__class__.__name__}_unmatched_group",
     )
-
-
+    if config.stream is None:
+        _apply_group_offloading_hook(module, unmatched_group, config=config)
+    else:
+        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
 
 
-def _apply_group_offloading_leaf_level(
-    module: torch.nn.Module,
-    offload_device: torch.device,
-    onload_device: torch.device,
-    non_blocking: bool,
-    stream: Optional[torch.cuda.Stream] = None,
-    record_stream: Optional[bool] = False,
-    low_cpu_mem_usage: bool = False,
-) -> None:
+def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
     requirements. However, it can be slower compared to other offloading methods due to the excessive number of device
     synchronizations. When using devices that support streams to overlap data transfer and computation, this method can
     reduce memory usage without any performance degradation.
-
-    Args:
-        module (`torch.nn.Module`):
-            The module to which group offloading is applied.
-        offload_device (`torch.device`):
-            The device to which the group of modules are offloaded. This should typically be the CPU.
-        onload_device (`torch.device`):
-            The device to which the group of modules are onloaded.
-        non_blocking (`bool`):
-            If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
-            and data transfer.
-        stream (`torch.cuda.Stream`, *optional*):
-            If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
-            for overlapping computation and data transfer.
-        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
-            as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
-            [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
-            details.
-        low_cpu_mem_usage (`bool`, defaults to `False`):
-            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
-            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
-            the CPU memory is a bottleneck but may counteract the benefits of using streams.
     """
-
     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
     for name, submodule in module.named_modules():
-        if not isinstance(submodule,
+        if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
             continue
         group = ModuleGroup(
             modules=[submodule],
-            offload_device=offload_device,
-            onload_device=onload_device,
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
+            offload_to_disk_path=config.offload_to_disk_path,
             offload_leader=submodule,
             onload_leader=submodule,
-            non_blocking=non_blocking,
-            stream=stream,
-            record_stream=record_stream,
-            low_cpu_mem_usage=low_cpu_mem_usage,
+            non_blocking=config.non_blocking,
+            stream=config.stream,
+            record_stream=config.record_stream,
+            low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
+            group_id=name,
         )
-        _apply_group_offloading_hook(submodule, group,
+        _apply_group_offloading_hook(submodule, group, config=config)
         modules_with_group_offloading.add(name)
 
     # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
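For completeness, a hedged sketch of enabling this leaf-level path through the public entry point; the checkpoint name and devices are placeholders, and with `use_stream=True` the transfers overlap computation as the docstring above describes:

```python
import torch
from diffusers import AutoModel
from diffusers.hooks import apply_group_offloading

model = AutoModel.from_pretrained("some/checkpoint", subfolder="transformer", torch_dtype=torch.bfloat16)
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",  # smallest memory footprint
    use_stream=True,            # overlap H2D/D2H copies with compute
    record_stream=True,         # valid only together with use_stream=True
)
```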
@@ -650,31 +711,33 @@ def _apply_group_offloading_leaf_level(
         parameters = parent_to_parameters.get(name, [])
         buffers = parent_to_buffers.get(name, [])
         parent_module = module_dict[name]
-        assert getattr(parent_module, "_diffusers_hook", None) is None
         group = ModuleGroup(
             modules=[],
-            offload_device=offload_device,
-            onload_device=onload_device,
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
             offload_leader=parent_module,
             onload_leader=parent_module,
+            offload_to_disk_path=config.offload_to_disk_path,
             parameters=parameters,
             buffers=buffers,
-            non_blocking=non_blocking,
-            stream=stream,
-            record_stream=record_stream,
-            low_cpu_mem_usage=low_cpu_mem_usage,
+            non_blocking=config.non_blocking,
+            stream=config.stream,
+            record_stream=config.record_stream,
+            low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
+            group_id=name,
         )
-        _apply_group_offloading_hook(parent_module, group,
+        _apply_group_offloading_hook(parent_module, group, config=config)
 
-    if stream is not None:
+    if config.stream is not None:
         # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
         # and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
         # execution order and apply prefetching in the correct order.
         unmatched_group = ModuleGroup(
             modules=[],
-            offload_device=offload_device,
-            onload_device=onload_device,
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
+            offload_to_disk_path=config.offload_to_disk_path,
             offload_leader=module,
             onload_leader=module,
             parameters=None,
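The comment above is the heart of the streamed path: the execution order is unknown until the first forward pass, so a lazy hook records it and only then wires up prefetching. A toy version of the order-recording idea (ours, not the library's `LazyPrefetchGroupOffloadingHook`):

```python
import torch

execution_order = []

def track(name):
    def hook(module, args, output):
        execution_order.append(name)  # record which layer ran, in order
        return output
    return hook

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU(), torch.nn.Linear(4, 2))
for name, layer in model.named_children():
    layer.register_forward_hook(track(name))

model(torch.randn(1, 4))
assert execution_order == ["0", "1", "2"]  # now group i+1 can be prefetched while group i runs
```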
@@ -682,37 +745,40 @@ def _apply_group_offloading_leaf_level(
             non_blocking=False,
             stream=None,
             record_stream=False,
-            low_cpu_mem_usage=low_cpu_mem_usage,
+            low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
+            group_id=_GROUP_ID_LAZY_LEAF,
         )
-        _apply_lazy_group_offloading_hook(module, unmatched_group,
+        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
 
 
 def _apply_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-
+    *,
+    config: GroupOffloadingConfig,
 ) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(module)
 
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group,
+        hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)
 
 
 def _apply_lazy_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-
+    *,
+    config: GroupOffloadingConfig,
 ) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(module)
 
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group,
+        hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)
 
     lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
@@ -779,15 +845,54 @@ def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn
     )
 
 
-def
+def _get_top_level_group_offload_hook(module: torch.nn.Module) -> Optional[GroupOffloadingHook]:
     for submodule in module.modules():
-        if hasattr(submodule, "_diffusers_hook")
-
-
+        if hasattr(submodule, "_diffusers_hook"):
+            group_offloading_hook = submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING)
+            if group_offloading_hook is not None:
+                return group_offloading_hook
+    return None
+
+
+def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
+    top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
+    return top_level_group_offload_hook is not None
 
 
 def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
-
-
-
+    top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
+    if top_level_group_offload_hook is not None:
+        return top_level_group_offload_hook.config.onload_device
     raise ValueError("Group offloading is not enabled for the provided module.")
+
+
+def _compute_group_hash(group_id):
+    hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
+    # first 16 characters for a reasonably short but unique name
+    return hashed_id[:16]
+
+
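`_compute_group_hash` turns a `group_id` into a short, stable tag, presumably to keep file names compact when `offload_to_disk_path` is set (the exact file-naming scheme is outside this hunk). What it produces:

```python
import hashlib

group_id = "transformer_blocks_0_1"
tag = hashlib.sha256(group_id.encode("utf-8")).hexdigest()[:16]
print(tag)  # 16 hex characters, deterministic for a given group_id
```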
+def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
+    r"""
+    Removes the group offloading hook from the module and re-applies it. This is useful when the module has been
+    modified in-place and the group offloading hook's references to tensors need to be updated. The in-place
+    modification can happen in a number of ways, for example, fusing QKV or unloading/loading LoRAs on-the-fly.
+
+    In this implementation, we assume that group offloading has only been applied at the top-level module, and
+    therefore all submodules have the same onload and offload devices. If this assumption is not true, say in the
+    case where the user has applied group offloading at multiple levels, this function will not work as expected.
+
+    There is some performance penalty associated with doing this when non-default streams are used, because we need to
+    retrace the execution order of the layers with `LazyPrefetchGroupOffloadingHook`.
+    """
+    top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
+
+    if top_level_group_offload_hook is None:
+        return
+
+    registry = HookRegistry.check_if_exists_or_initialize(module)
+    registry.remove_hook(_GROUP_OFFLOADING, recurse=True)
+    registry.remove_hook(_LAYER_EXECUTION_TRACKER, recurse=True)
+    registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=True)
+
+    _apply_group_offloading(module, top_level_group_offload_hook.config)