diffusers 0.32.2 → 0.33.1 (py3-none-any.whl)
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +121 -86
- diffusers/loaders/lora_conversion_utils.py +504 -44
- diffusers/loaders/lora_pipeline.py +1769 -181
- diffusers/loaders/peft.py +167 -57
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +646 -72
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +20 -7
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +595 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +724 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +727 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +9 -1
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +2 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/METADATA +21 -4
- diffusers-0.33.1.dist-info/RECORD +608 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/WHEEL +1 -1
- diffusers-0.32.2.dist-info/RECORD +0 -550
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/LICENSE +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/top_level.txt +0 -0
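
The listing above already sketches the shape of the 0.33 release: a new `diffusers/hooks/` package (group offloading, layerwise casting, Pyramid Attention Broadcast, FasterCache), a new `AutoModel` entry point, a Quanto quantization backend, remote utilities, and a wave of new pipelines (Wan, Lumina2, CogView4, OmniGen, EasyAnimate, ConsisID, SANA Sprint, among others). As orientation for the hunks below, here is a minimal sketch of how two of the new memory features are exposed on models in 0.33 — the method names match the 0.33 release notes, but treat the exact signatures as assumptions and verify them against your installed version:

```python
import torch
from diffusers import AutoModel  # new in 0.33 (diffusers/models/auto_model.py)

# Load a model component without naming its concrete class up front.
transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Group offloading (diffusers/hooks/group_offloading.py): keep only a few
# transformer blocks on the accelerator at a time, streaming the rest from CPU.
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
)
```

The hunks captured on this page all come from `diffusers/models/transformers/` and show the model-side plumbing for these features.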
--- a/diffusers/models/transformers/dit_transformer_2d.py
+++ b/diffusers/models/transformers/dit_transformer_2d.py
@@ -18,7 +18,7 @@ import torch.nn.functional as F
 from torch import nn

 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version, logging
+from ...utils import logging
 from ..attention import BasicTransformerBlock
 from ..embeddings import PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
@@ -64,7 +64,9 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
            A small constant added to the denominator in normalization layers to prevent division by zero.
    """

+    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
     _supports_gradient_checkpointing = True
+    _supports_group_offloading = False

     @register_to_config
     def __init__(
@@ -143,10 +145,6 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
             self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
         )

-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -185,19 +183,8 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
         # 2. Blocks
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     None,
                     None,
@@ -205,7 +192,6 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
                     timestep,
                     cross_attention_kwargs,
                     class_labels,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
--- a/diffusers/models/transformers/hunyuan_transformer_2d.py
+++ b/diffusers/models/transformers/hunyuan_transformer_2d.py
@@ -244,6 +244,9 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
            Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2
    """

+    _skip_layerwise_casting_patterns = ["pos_embed", "norm", "pooler"]
+    _supports_group_offloading = False
+
     @register_to_config
     def __init__(
         self,
@@ -277,9 +280,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
             act_fn="silu_fp32",
         )

-        self.text_embedding_padding = nn.Parameter(
-            torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32)
-        )
+        self.text_embedding_padding = nn.Parameter(torch.randn(text_len + text_len_t5, cross_attention_dim))

         self.pos_embed = PatchEmbed(
             height=sample_size,
--- a/diffusers/models/transformers/latte_transformer_3d.py
+++ b/diffusers/models/transformers/latte_transformer_3d.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from typing import Optional

 import torch
@@ -19,13 +20,14 @@ from torch import nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
 from ..attention import BasicTransformerBlock
+from ..cache_utils import CacheMixin
 from ..embeddings import PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormSingle


-class LatteTransformer3DModel(ModelMixin, ConfigMixin):
+class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
     _supports_gradient_checkpointing = True

     """
@@ -65,6 +67,8 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
            The number of frames in the video-like data.
    """

+    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+
     @register_to_config
     def __init__(
         self,
@@ -162,9 +166,6 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):

         self.gradient_checkpointing = False

-    def _set_gradient_checkpointing(self, module, value=False):
-        self.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -226,20 +227,24 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
         # Prepare text embeddings for spatial block
         # batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size
         encoder_hidden_states = self.caption_projection(encoder_hidden_states)  # 3 120 1152
-        encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view(
-            -1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]
-        )
+        encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(
+            num_frame, dim=0, output_size=encoder_hidden_states.shape[0] * num_frame
+        ).view(-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1])

         # Prepare timesteps for spatial and temporal block
-        timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1])
-        timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1])
+        timestep_spatial = timestep.repeat_interleave(
+            num_frame, dim=0, output_size=timestep.shape[0] * num_frame
+        ).view(-1, timestep.shape[-1])
+        timestep_temp = timestep.repeat_interleave(
+            num_patches, dim=0, output_size=timestep.shape[0] * num_patches
+        ).view(-1, timestep.shape[-1])

         # Spatial and temporal transformer blocks
         for i, (spatial_block, temp_block) in enumerate(
             zip(self.transformer_blocks, self.temporal_transformer_blocks)
         ):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = torch.utils.checkpoint.checkpoint(
+                hidden_states = self._gradient_checkpointing_func(
                     spatial_block,
                     hidden_states,
                     None,  # attention_mask
@@ -248,7 +253,6 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
                     timestep_spatial,
                     None,  # cross_attention_kwargs
                     None,  # class_labels
-                    use_reentrant=False,
                 )
             else:
                 hidden_states = spatial_block(
@@ -269,10 +273,10 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
             hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])

             if i == 0 and num_frame > 1:
-                hidden_states = hidden_states + self.temp_pos_embed
+                hidden_states = hidden_states + self.temp_pos_embed.to(hidden_states.dtype)

             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = torch.utils.checkpoint.checkpoint(
+                hidden_states = self._gradient_checkpointing_func(
                     temp_block,
                     hidden_states,
                     None,  # attention_mask
@@ -281,7 +285,6 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
                     timestep_temp,
                     None,  # cross_attention_kwargs
                     None,  # class_labels
-                    use_reentrant=False,
                 )
             else:
                 hidden_states = temp_block(
@@ -300,7 +303,9 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
         ).permute(0, 2, 1, 3)
         hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])

-        embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
+        embedded_timestep = embedded_timestep.repeat_interleave(
+            num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame
+        ).view(-1, embedded_timestep.shape[-1])
         shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
         hidden_states = self.norm_out(hidden_states)
         # Modulation
--- a/diffusers/models/transformers/lumina_nextdit2d.py
+++ b/diffusers/models/transformers/lumina_nextdit2d.py
@@ -98,7 +98,7 @@ class LuminaNextDiTBlock(nn.Module):

         self.feed_forward = LuminaFeedForward(
             dim=dim,
-            inner_dim=4 * dim,
+            inner_dim=int(4 * 2 * dim / 3),
             multiple_of=multiple_of,
             ffn_dim_multiplier=ffn_dim_multiplier,
         )
@@ -221,6 +221,8 @@ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
            overall scale of the model's operations.
    """

+    _skip_layerwise_casting_patterns = ["patch_embedder", "norm", "ffn_norm"]
+
     @register_to_config
     def __init__(
         self,
--- a/diffusers/models/transformers/pixart_transformer_2d.py
+++ b/diffusers/models/transformers/pixart_transformer_2d.py
@@ -17,7 +17,7 @@ import torch
 from torch import nn

 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version, logging
+from ...utils import logging
 from ..attention import BasicTransformerBlock
 from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
@@ -79,6 +79,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):

     _supports_gradient_checkpointing = True
     _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
+    _skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"]

     @register_to_config
     def __init__(
@@ -183,10 +184,6 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
             in_features=self.config.caption_channels, hidden_size=self.inner_dim
         )

-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     @property
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -387,19 +384,8 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
         # 2. Blocks
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -407,7 +393,6 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
                     timestep,
                     cross_attention_kwargs,
                     None,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
--- a/diffusers/models/transformers/prior_transformer.py
+++ b/diffusers/models/transformers/prior_transformer.py
@@ -353,7 +353,11 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
             attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
             attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
             attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
-            attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
+            attention_mask = attention_mask.repeat_interleave(
+                self.config.num_attention_heads,
+                dim=0,
+                output_size=attention_mask.shape[0] * self.config.num_attention_heads,
+            )

         if self.norm_in is not None:
             hidden_states = self.norm_in(hidden_states)
--- a/diffusers/models/transformers/sana_transformer.py
+++ b/diffusers/models/transformers/sana_transformer.py
@@ -15,18 +15,18 @@
 from typing import Any, Dict, Optional, Tuple, Union

 import torch
+import torch.nn.functional as F
 from torch import nn

 from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention_processor import (
     Attention,
     AttentionProcessor,
-    AttnProcessor2_0,
     SanaLinearAttnProcessor2_0,
 )
-from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
+from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormSingle, RMSNorm
@@ -82,6 +82,109 @@ class GLUMBConv(nn.Module):
         return hidden_states


+class SanaModulatedNorm(nn.Module):
+    def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps)
+
+    def forward(
+        self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor
+    ) -> torch.Tensor:
+        hidden_states = self.norm(hidden_states)
+        shift, scale = (scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)).chunk(2, dim=1)
+        hidden_states = hidden_states * (1 + scale) + shift
+        return hidden_states
+
+
+class SanaCombinedTimestepGuidanceEmbeddings(nn.Module):
+    def __init__(self, embedding_dim):
+        super().__init__()
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+        self.guidance_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
+
+    def forward(self, timestep: torch.Tensor, guidance: torch.Tensor = None, hidden_dtype: torch.dtype = None):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
+
+        guidance_proj = self.guidance_condition_proj(guidance)
+        guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=hidden_dtype))
+        conditioning = timesteps_emb + guidance_emb
+
+        return self.linear(self.silu(conditioning)), conditioning
+
+
+class SanaAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class SanaTransformerBlock(nn.Module):
     r"""
     Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
@@ -101,6 +204,7 @@ class SanaTransformerBlock(nn.Module):
         norm_eps: float = 1e-6,
         attention_out_bias: bool = True,
         mlp_ratio: float = 2.5,
+        qk_norm: Optional[str] = None,
     ) -> None:
         super().__init__()

@@ -110,6 +214,8 @@ class SanaTransformerBlock(nn.Module):
             query_dim=dim,
             heads=num_attention_heads,
             dim_head=attention_head_dim,
+            kv_heads=num_attention_heads if qk_norm is not None else None,
+            qk_norm=qk_norm,
             dropout=dropout,
             bias=attention_bias,
             cross_attention_dim=None,
@@ -121,13 +227,15 @@ class SanaTransformerBlock(nn.Module):
         self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
         self.attn2 = Attention(
             query_dim=dim,
+            qk_norm=qk_norm,
+            kv_heads=num_cross_attention_heads if qk_norm is not None else None,
             cross_attention_dim=cross_attention_dim,
             heads=num_cross_attention_heads,
             dim_head=cross_attention_head_dim,
             dropout=dropout,
             bias=True,
             out_bias=attention_out_bias,
-            processor=AttnProcessor2_0(),
+            processor=SanaAttnProcessor2_0(),
         )

         # 3. Feed-forward
@@ -181,7 +289,7 @@ class SanaTransformerBlock(nn.Module):
         return hidden_states


-class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     r"""
     A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models.

@@ -218,10 +326,15 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
            Whether to use elementwise affinity in the normalization layer.
        norm_eps (`float`, defaults to `1e-6`):
            The epsilon value for the normalization layer.
+        qk_norm (`str`, *optional*, defaults to `None`):
+            The normalization to use for the query and key.
+        timestep_scale (`float`, defaults to `1.0`):
+            The scale to use for the timesteps.
    """

     _supports_gradient_checkpointing = True
-    _no_split_modules = ["SanaTransformerBlock", "PatchEmbed"]
+    _no_split_modules = ["SanaTransformerBlock", "PatchEmbed", "SanaModulatedNorm"]
+    _skip_layerwise_casting_patterns = ["patch_embed", "norm"]

     @register_to_config
     def __init__(
@@ -243,6 +356,10 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         norm_elementwise_affine: bool = False,
         norm_eps: float = 1e-6,
         interpolation_scale: Optional[int] = None,
+        guidance_embeds: bool = False,
+        guidance_embeds_scale: float = 0.1,
+        qk_norm: Optional[str] = None,
+        timestep_scale: float = 1.0,
     ) -> None:
         super().__init__()

@@ -250,7 +367,6 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         inner_dim = num_attention_heads * attention_head_dim

         # 1. Patch Embedding
-        interpolation_scale = interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)
         self.patch_embed = PatchEmbed(
             height=sample_size,
             width=sample_size,
@@ -258,10 +374,14 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             in_channels=in_channels,
             embed_dim=inner_dim,
             interpolation_scale=interpolation_scale,
+            pos_embed_type="sincos" if interpolation_scale is not None else None,
         )

         # 2. Additional condition embeddings
-        self.time_embed = AdaLayerNormSingle(inner_dim)
+        if guidance_embeds:
+            self.time_embed = SanaCombinedTimestepGuidanceEmbeddings(inner_dim)
+        else:
+            self.time_embed = AdaLayerNormSingle(inner_dim)

         self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
         self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
@@ -281,6 +401,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     norm_elementwise_affine=norm_elementwise_affine,
                     norm_eps=norm_eps,
                     mlp_ratio=mlp_ratio,
+                    qk_norm=qk_norm,
                 )
                 for _ in range(num_layers)
             ]
@@ -288,16 +409,11 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):

         # 4. Output blocks
         self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
-
-        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+        self.norm_out = SanaModulatedNorm(inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)

         self.gradient_checkpointing = False

-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     @property
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -362,7 +478,8 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         self,
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
-        timestep: torch.LongTensor,
+        timestep: torch.Tensor,
+        guidance: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -413,9 +530,14 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):

         hidden_states = self.patch_embed(hidden_states)

-        timestep, embedded_timestep = self.time_embed(
-            timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
-        )
+        if guidance is not None:
+            timestep, embedded_timestep = self.time_embed(
+                timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
+            )
+        else:
+            timestep, embedded_timestep = self.time_embed(
+                timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+            )

         encoder_hidden_states = self.caption_projection(encoder_hidden_states)
         encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
@@ -424,21 +546,9 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):

         # 2. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module, return_dict=None):
-                def custom_forward(*inputs):
-                    if return_dict is not None:
-                        return module(*inputs, return_dict=return_dict)
-                    else:
-                        return module(*inputs)
-
-                return custom_forward
-
-            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-
             for block in self.transformer_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -446,7 +556,6 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     timestep,
                     post_patch_height,
                     post_patch_width,
-                    **ckpt_kwargs,
                 )

         else:
@@ -462,13 +571,8 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             )

         # 3. Normalization
-        shift, scale = (
-            self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
-        ).chunk(2, dim=1)
-        hidden_states = self.norm_out(hidden_states)
+        hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)

-        # 4. Modulation
-        hidden_states = hidden_states * (1 + scale) + shift
         hidden_states = self.proj_out(hidden_states)

         # 5. Unpatchify
--- a/diffusers/models/transformers/stable_audio_transformer.py
+++ b/diffusers/models/transformers/stable_audio_transformer.py
@@ -13,7 +13,7 @@
 # limitations under the License.


-from typing import Any, Dict, Optional, Union
+from typing import Dict, Optional, Union

 import numpy as np
 import torch
@@ -29,7 +29,7 @@ from ...models.attention_processor import (
 )
 from ...models.modeling_utils import ModelMixin
 from ...models.transformers.transformer_2d import Transformer2DModelOutput
-from ...utils import is_torch_version, logging
+from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph


@@ -211,6 +211,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
    """

     _supports_gradient_checkpointing = True
+    _skip_layerwise_casting_patterns = ["preprocess_conv", "postprocess_conv", "^proj_in$", "^proj_out$", "norm"]

     @register_to_config
     def __init__(
@@ -345,10 +346,6 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
        """
        self.set_attn_processor(StableAudioAttnProcessor2_0())

-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
    def forward(
        self,
        hidden_states: torch.FloatTensor,
@@ -415,25 +412,13 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):

        for block in self.transformer_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     attention_mask,
                     cross_attention_hidden_states,
                     encoder_attention_mask,
                     rotary_embedding,
-                    **ckpt_kwargs,
                 )

             else: