diffusers 0.32.2__py3-none-any.whl → 0.33.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +121 -86
- diffusers/loaders/lora_conversion_utils.py +504 -44
- diffusers/loaders/lora_pipeline.py +1769 -181
- diffusers/loaders/peft.py +167 -57
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +646 -72
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +20 -7
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +593 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +9 -1
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +2 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
- diffusers-0.33.0.dist-info/RECORD +608 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
- diffusers-0.32.2.dist-info/RECORD +0 -550
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
@@ -18,7 +18,6 @@ from typing import Any, Dict, Optional, Tuple, Union
|
|
18
18
|
import numpy as np
|
19
19
|
import torch
|
20
20
|
import torch.nn as nn
|
21
|
-
import torch.nn.functional as F
|
22
21
|
|
23
22
|
from ...configuration_utils import ConfigMixin, register_to_config
|
24
23
|
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
@@ -32,9 +31,10 @@ from ...models.attention_processor import (
|
|
32
31
|
)
|
33
32
|
from ...models.modeling_utils import ModelMixin
|
34
33
|
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
35
|
-
from ...utils import USE_PEFT_BACKEND,
|
34
|
+
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
36
35
|
from ...utils.import_utils import is_torch_npu_available
|
37
36
|
from ...utils.torch_utils import maybe_allow_in_graph
|
37
|
+
from ..cache_utils import CacheMixin
|
38
38
|
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
39
39
|
from ..modeling_outputs import Transformer2DModelOutput
|
40
40
|
|
@@ -44,20 +44,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|
44
44
|
|
45
45
|
@maybe_allow_in_graph
|
46
46
|
class FluxSingleTransformerBlock(nn.Module):
|
47
|
-
|
48
|
-
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
|
49
|
-
|
50
|
-
Reference: https://arxiv.org/abs/2403.03206
|
51
|
-
|
52
|
-
Parameters:
|
53
|
-
dim (`int`): The number of channels in the input and output.
|
54
|
-
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
55
|
-
attention_head_dim (`int`): The number of channels in each head.
|
56
|
-
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
|
57
|
-
processing of `context` conditions.
|
58
|
-
"""
|
59
|
-
|
60
|
-
def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
|
47
|
+
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
|
61
48
|
super().__init__()
|
62
49
|
self.mlp_hidden_dim = int(dim * mlp_ratio)
|
63
50
|
|
@@ -67,9 +54,15 @@ class FluxSingleTransformerBlock(nn.Module):
|
|
67
54
|
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
68
55
|
|
69
56
|
if is_torch_npu_available():
|
57
|
+
deprecation_message = (
|
58
|
+
"Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
|
59
|
+
"should be set explicitly using the `set_attn_processor` method."
|
60
|
+
)
|
61
|
+
deprecate("npu_processor", "0.34.0", deprecation_message)
|
70
62
|
processor = FluxAttnProcessor2_0_NPU()
|
71
63
|
else:
|
72
64
|
processor = FluxAttnProcessor2_0()
|
65
|
+
|
73
66
|
self.attn = Attention(
|
74
67
|
query_dim=dim,
|
75
68
|
cross_attention_dim=None,
|
@@ -85,11 +78,11 @@ class FluxSingleTransformerBlock(nn.Module):
|
|
85
78
|
|
86
79
|
def forward(
|
87
80
|
self,
|
88
|
-
hidden_states: torch.
|
89
|
-
temb: torch.
|
90
|
-
image_rotary_emb=None,
|
91
|
-
joint_attention_kwargs=None,
|
92
|
-
):
|
81
|
+
hidden_states: torch.Tensor,
|
82
|
+
temb: torch.Tensor,
|
83
|
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
84
|
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
85
|
+
) -> torch.Tensor:
|
93
86
|
residual = hidden_states
|
94
87
|
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
95
88
|
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
@@ -112,32 +105,14 @@ class FluxSingleTransformerBlock(nn.Module):
|
|
112
105
|
|
113
106
|
@maybe_allow_in_graph
|
114
107
|
class FluxTransformerBlock(nn.Module):
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
Reference: https://arxiv.org/abs/2403.03206
|
119
|
-
|
120
|
-
Parameters:
|
121
|
-
dim (`int`): The number of channels in the input and output.
|
122
|
-
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
123
|
-
attention_head_dim (`int`): The number of channels in each head.
|
124
|
-
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
|
125
|
-
processing of `context` conditions.
|
126
|
-
"""
|
127
|
-
|
128
|
-
def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
|
108
|
+
def __init__(
|
109
|
+
self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
|
110
|
+
):
|
129
111
|
super().__init__()
|
130
112
|
|
131
113
|
self.norm1 = AdaLayerNormZero(dim)
|
132
|
-
|
133
114
|
self.norm1_context = AdaLayerNormZero(dim)
|
134
115
|
|
135
|
-
if hasattr(F, "scaled_dot_product_attention"):
|
136
|
-
processor = FluxAttnProcessor2_0()
|
137
|
-
else:
|
138
|
-
raise ValueError(
|
139
|
-
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
|
140
|
-
)
|
141
116
|
self.attn = Attention(
|
142
117
|
query_dim=dim,
|
143
118
|
cross_attention_dim=None,
|
@@ -147,7 +122,7 @@ class FluxTransformerBlock(nn.Module):
|
|
147
122
|
out_dim=dim,
|
148
123
|
context_pre_only=False,
|
149
124
|
bias=True,
|
150
|
-
processor=
|
125
|
+
processor=FluxAttnProcessor2_0(),
|
151
126
|
qk_norm=qk_norm,
|
152
127
|
eps=eps,
|
153
128
|
)
|
@@ -158,18 +133,14 @@ class FluxTransformerBlock(nn.Module):
|
|
158
133
|
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
159
134
|
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
160
135
|
|
161
|
-
# let chunk size default to None
|
162
|
-
self._chunk_size = None
|
163
|
-
self._chunk_dim = 0
|
164
|
-
|
165
136
|
def forward(
|
166
137
|
self,
|
167
|
-
hidden_states: torch.
|
168
|
-
encoder_hidden_states: torch.
|
169
|
-
temb: torch.
|
170
|
-
image_rotary_emb=None,
|
171
|
-
joint_attention_kwargs=None,
|
172
|
-
):
|
138
|
+
hidden_states: torch.Tensor,
|
139
|
+
encoder_hidden_states: torch.Tensor,
|
140
|
+
temb: torch.Tensor,
|
141
|
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
142
|
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
143
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
173
144
|
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
174
145
|
|
175
146
|
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
@@ -220,27 +191,42 @@ class FluxTransformerBlock(nn.Module):
|
|
220
191
|
|
221
192
|
|
222
193
|
class FluxTransformer2DModel(
|
223
|
-
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
|
194
|
+
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
|
224
195
|
):
|
225
196
|
"""
|
226
197
|
The Transformer model introduced in Flux.
|
227
198
|
|
228
199
|
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
229
200
|
|
230
|
-
|
231
|
-
patch_size (`int
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
201
|
+
Args:
|
202
|
+
patch_size (`int`, defaults to `1`):
|
203
|
+
Patch size to turn the input data into small patches.
|
204
|
+
in_channels (`int`, defaults to `64`):
|
205
|
+
The number of channels in the input.
|
206
|
+
out_channels (`int`, *optional*, defaults to `None`):
|
207
|
+
The number of channels in the output. If not specified, it defaults to `in_channels`.
|
208
|
+
num_layers (`int`, defaults to `19`):
|
209
|
+
The number of layers of dual stream DiT blocks to use.
|
210
|
+
num_single_layers (`int`, defaults to `38`):
|
211
|
+
The number of layers of single stream DiT blocks to use.
|
212
|
+
attention_head_dim (`int`, defaults to `128`):
|
213
|
+
The number of dimensions to use for each attention head.
|
214
|
+
num_attention_heads (`int`, defaults to `24`):
|
215
|
+
The number of attention heads to use.
|
216
|
+
joint_attention_dim (`int`, defaults to `4096`):
|
217
|
+
The number of dimensions to use for the joint attention (embedding/channel dimension of
|
218
|
+
`encoder_hidden_states`).
|
219
|
+
pooled_projection_dim (`int`, defaults to `768`):
|
220
|
+
The number of dimensions to use for the pooled projection.
|
221
|
+
guidance_embeds (`bool`, defaults to `False`):
|
222
|
+
Whether to use guidance embeddings for guidance-distilled variant of the model.
|
223
|
+
axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
|
224
|
+
The dimensions to use for the rotary positional embeddings.
|
240
225
|
"""
|
241
226
|
|
242
227
|
_supports_gradient_checkpointing = True
|
243
228
|
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
|
229
|
+
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
244
230
|
|
245
231
|
@register_to_config
|
246
232
|
def __init__(
|
@@ -259,7 +245,7 @@ class FluxTransformer2DModel(
|
|
259
245
|
):
|
260
246
|
super().__init__()
|
261
247
|
self.out_channels = out_channels or in_channels
|
262
|
-
self.inner_dim =
|
248
|
+
self.inner_dim = num_attention_heads * attention_head_dim
|
263
249
|
|
264
250
|
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
|
265
251
|
|
@@ -267,20 +253,20 @@ class FluxTransformer2DModel(
|
|
267
253
|
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
|
268
254
|
)
|
269
255
|
self.time_text_embed = text_time_guidance_cls(
|
270
|
-
embedding_dim=self.inner_dim, pooled_projection_dim=
|
256
|
+
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
|
271
257
|
)
|
272
258
|
|
273
|
-
self.context_embedder = nn.Linear(
|
274
|
-
self.x_embedder = nn.Linear(
|
259
|
+
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
|
260
|
+
self.x_embedder = nn.Linear(in_channels, self.inner_dim)
|
275
261
|
|
276
262
|
self.transformer_blocks = nn.ModuleList(
|
277
263
|
[
|
278
264
|
FluxTransformerBlock(
|
279
265
|
dim=self.inner_dim,
|
280
|
-
num_attention_heads=
|
281
|
-
attention_head_dim=
|
266
|
+
num_attention_heads=num_attention_heads,
|
267
|
+
attention_head_dim=attention_head_dim,
|
282
268
|
)
|
283
|
-
for
|
269
|
+
for _ in range(num_layers)
|
284
270
|
]
|
285
271
|
)
|
286
272
|
|
@@ -288,10 +274,10 @@ class FluxTransformer2DModel(
|
|
288
274
|
[
|
289
275
|
FluxSingleTransformerBlock(
|
290
276
|
dim=self.inner_dim,
|
291
|
-
num_attention_heads=
|
292
|
-
attention_head_dim=
|
277
|
+
num_attention_heads=num_attention_heads,
|
278
|
+
attention_head_dim=attention_head_dim,
|
293
279
|
)
|
294
|
-
for
|
280
|
+
for _ in range(num_single_layers)
|
295
281
|
]
|
296
282
|
)
|
297
283
|
|
@@ -400,10 +386,6 @@ class FluxTransformer2DModel(
|
|
400
386
|
if self.original_attn_processors is not None:
|
401
387
|
self.set_attn_processor(self.original_attn_processors)
|
402
388
|
|
403
|
-
def _set_gradient_checkpointing(self, module, value=False):
|
404
|
-
if hasattr(module, "gradient_checkpointing"):
|
405
|
-
module.gradient_checkpointing = value
|
406
|
-
|
407
389
|
def forward(
|
408
390
|
self,
|
409
391
|
hidden_states: torch.Tensor,
|
@@ -418,16 +400,16 @@ class FluxTransformer2DModel(
|
|
418
400
|
controlnet_single_block_samples=None,
|
419
401
|
return_dict: bool = True,
|
420
402
|
controlnet_blocks_repeat: bool = False,
|
421
|
-
) -> Union[torch.
|
403
|
+
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
422
404
|
"""
|
423
405
|
The [`FluxTransformer2DModel`] forward method.
|
424
406
|
|
425
407
|
Args:
|
426
|
-
hidden_states (`torch.
|
408
|
+
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
427
409
|
Input `hidden_states`.
|
428
|
-
encoder_hidden_states (`torch.
|
410
|
+
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
429
411
|
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
430
|
-
pooled_projections (`torch.
|
412
|
+
pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
431
413
|
from the embeddings of input conditions.
|
432
414
|
timestep ( `torch.LongTensor`):
|
433
415
|
Used to indicate denoising step.
|
@@ -498,24 +480,12 @@ class FluxTransformer2DModel(
|
|
498
480
|
|
499
481
|
for index_block, block in enumerate(self.transformer_blocks):
|
500
482
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
501
|
-
|
502
|
-
|
503
|
-
def custom_forward(*inputs):
|
504
|
-
if return_dict is not None:
|
505
|
-
return module(*inputs, return_dict=return_dict)
|
506
|
-
else:
|
507
|
-
return module(*inputs)
|
508
|
-
|
509
|
-
return custom_forward
|
510
|
-
|
511
|
-
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
512
|
-
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
513
|
-
create_custom_forward(block),
|
483
|
+
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
484
|
+
block,
|
514
485
|
hidden_states,
|
515
486
|
encoder_hidden_states,
|
516
487
|
temb,
|
517
488
|
image_rotary_emb,
|
518
|
-
**ckpt_kwargs,
|
519
489
|
)
|
520
490
|
|
521
491
|
else:
|
@@ -542,23 +512,11 @@ class FluxTransformer2DModel(
|
|
542
512
|
|
543
513
|
for index_block, block in enumerate(self.single_transformer_blocks):
|
544
514
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
545
|
-
|
546
|
-
|
547
|
-
def custom_forward(*inputs):
|
548
|
-
if return_dict is not None:
|
549
|
-
return module(*inputs, return_dict=return_dict)
|
550
|
-
else:
|
551
|
-
return module(*inputs)
|
552
|
-
|
553
|
-
return custom_forward
|
554
|
-
|
555
|
-
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
556
|
-
hidden_states = torch.utils.checkpoint.checkpoint(
|
557
|
-
create_custom_forward(block),
|
515
|
+
hidden_states = self._gradient_checkpointing_func(
|
516
|
+
block,
|
558
517
|
hidden_states,
|
559
518
|
temb,
|
560
519
|
image_rotary_emb,
|
561
|
-
**ckpt_kwargs,
|
562
520
|
)
|
563
521
|
|
564
522
|
else:
|