diffusers 0.32.2__py3-none-any.whl → 0.33.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +121 -86
- diffusers/loaders/lora_conversion_utils.py +504 -44
- diffusers/loaders/lora_pipeline.py +1769 -181
- diffusers/loaders/peft.py +167 -57
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +646 -72
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +20 -7
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +595 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +724 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +727 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +9 -1
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +2 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/METADATA +21 -4
- diffusers-0.33.1.dist-info/RECORD +608 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/WHEEL +1 -1
- diffusers-0.32.2.dist-info/RECORD +0 -550
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/LICENSE +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/top_level.txt +0 -0
@@ -18,7 +18,7 @@ import torch.nn.functional as F
 from torch import nn
 
 from ...configuration_utils import LegacyConfigMixin, register_to_config
-from ...utils import deprecate,
+from ...utils import deprecate, logging
 from ..attention import BasicTransformerBlock
 from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
@@ -66,6 +66,7 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
 
     _supports_gradient_checkpointing = True
     _no_split_modules = ["BasicTransformerBlock"]
+    _skip_layerwise_casting_patterns = ["latent_image_embedding", "norm"]
 
     @register_to_config
     def __init__(
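The `_skip_layerwise_casting_patterns` attribute added here (and to many other models in this release) names submodules, by regex pattern, that should be excluded when the rest of the model's weights are stored in a lower-precision dtype via the new layerwise-casting hooks (`diffusers/hooks/layerwise_casting.py` in the file list above). A rough sketch of the idea only; the helper below is invented for illustration and is not the library's API:

```python
import re
from typing import Iterable

import torch
import torch.nn as nn


def cast_weights_except_patterns(model: nn.Module, storage_dtype: torch.dtype, skip_patterns: Iterable[str]) -> None:
    # Illustrative only: store most weights in a low-precision dtype, but leave
    # modules whose qualified name matches a skip pattern (e.g. "norm") in their
    # original precision, since normalization/embedding layers are numerically sensitive.
    patterns = list(skip_patterns)
    for name, module in model.named_modules():
        if any(re.search(pattern, name) for pattern in patterns):
            continue
        for param in module.parameters(recurse=False):
            param.data = param.data.to(storage_dtype)
```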
@@ -210,9 +211,9 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
 
     def _init_vectorized_inputs(self, norm_type):
         assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
-        assert (
-
-        )
+        assert self.config.num_vector_embeds is not None, (
+            "Transformer2DModel over discrete input must provide num_embed"
+        )
 
         self.height = self.config.sample_size
         self.width = self.config.sample_size
@@ -320,10 +321,6 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
                 in_features=self.caption_channels, hidden_size=self.inner_dim
             )
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -416,19 +413,8 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
         # 2. Blocks
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -436,7 +422,6 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
                     timestep,
                     cross_attention_kwargs,
                     class_labels,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
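Throughout these hunks the per-model `create_custom_forward` / `ckpt_kwargs` boilerplate is replaced by a single `self._gradient_checkpointing_func(block, *args)` call supplied by the shared modeling utilities. A minimal sketch of what such a helper might do, assuming the same non-reentrant `torch.utils.checkpoint` behavior the removed code requested:

```python
import torch
import torch.nn as nn


def _gradient_checkpointing_func(module: nn.Module, *args):
    # Illustrative stand-in only: recompute `module`'s forward during the backward
    # pass instead of storing its activations, matching the removed boilerplate
    # that passed use_reentrant=False on PyTorch >= 1.11.
    return torch.utils.checkpoint.checkpoint(module.__call__, *args, use_reentrant=False)
```

Each block is then checkpointed as `hidden_states = self._gradient_checkpointing_func(block, hidden_states, ...)`, so the nested closure and the `ckpt_kwargs` dictionary are no longer needed.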
@@ -13,17 +13,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import
+from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import
+from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import AllegroAttnProcessor2_0, Attention
+from ..cache_utils import CacheMixin
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -172,7 +173,7 @@ class AllegroTransformerBlock(nn.Module):
         return hidden_states
 
 
-class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
+class AllegroTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
     _supports_gradient_checkpointing = True
 
     """
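`AllegroTransformer3DModel` now also inherits from the new `CacheMixin` (`diffusers/models/cache_utils.py`), which lets the caching hooks added in this release (`diffusers/hooks/pyramid_attention_broadcast.py`, `diffusers/hooks/faster_cache.py`) reuse expensive intermediate outputs across denoising steps. A toy sketch of the caching idea only; the wrapper below is invented for illustration and is not the library's API:

```python
import torch.nn as nn


class CachedBlock(nn.Module):
    """Illustrative only: recompute the wrapped block every `refresh_every`
    calls and replay the cached output in between, trading a little accuracy
    for fewer attention computations across denoising steps."""

    def __init__(self, block: nn.Module, refresh_every: int = 2):
        super().__init__()
        self.block = block
        self.refresh_every = refresh_every
        self._calls = 0
        self._cached_output = None

    def forward(self, *args, **kwargs):
        if self._cached_output is None or self._calls % self.refresh_every == 0:
            self._cached_output = self.block(*args, **kwargs)
        self._calls += 1
        return self._cached_output
```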
@@ -221,6 +222,9 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
             Scaling factor to apply in 3D positional embeddings across time dimension.
     """
 
+    _supports_gradient_checkpointing = True
+    _skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"]
+
     @register_to_config
     def __init__(
         self,
@@ -300,9 +304,6 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
 
         self.gradient_checkpointing = False
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        self.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -372,23 +373,14 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
         for i, block in enumerate(self.transformer_blocks):
             # TODO(aryan): Implement gradient checkpointing
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     encoder_hidden_states,
                     timestep,
                     attention_mask,
                     encoder_attention_mask,
                     image_rotary_emb,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 
-from typing import
+from typing import Dict, Union
 
 import torch
 import torch.nn as nn
@@ -27,7 +27,7 @@ from ...models.attention_processor import (
 )
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous
-from ...utils import
+from ...utils import logging
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
@@ -166,6 +166,8 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = True
+    _skip_layerwise_casting_patterns = ["patch_embed", "norm"]
+    _no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"]
 
     @register_to_config
     def __init__(
@@ -287,10 +289,6 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -342,20 +340,11 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
 
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     encoder_hidden_states,
                     emb,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states, encoder_hidden_states = block(
@@ -0,0 +1,462 @@
+# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import FeedForward
+from ..attention_processor import Attention
+from ..cache_utils import CacheMixin
+from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class CogView4PatchEmbed(nn.Module):
+    def __init__(
+        self,
+        in_channels: int = 16,
+        hidden_size: int = 2560,
+        patch_size: int = 2,
+        text_hidden_size: int = 4096,
+    ):
+        super().__init__()
+        self.patch_size = patch_size
+
+        self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
+        self.text_proj = nn.Linear(text_hidden_size, hidden_size)
+
+    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, channel, height, width = hidden_states.shape
+        post_patch_height = height // self.patch_size
+        post_patch_width = width // self.patch_size
+
+        hidden_states = hidden_states.reshape(
+            batch_size, channel, post_patch_height, self.patch_size, post_patch_width, self.patch_size
+        )
+        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
+        hidden_states = self.proj(hidden_states)
+        encoder_hidden_states = self.text_proj(encoder_hidden_states)
+
+        return hidden_states, encoder_hidden_states
+
+
+class CogView4AdaLayerNormZero(nn.Module):
+    def __init__(self, embedding_dim: int, dim: int) -> None:
+        super().__init__()
+
+        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+        self.norm_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+        self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
+
+    def forward(
+        self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        norm_hidden_states = self.norm(hidden_states)
+        norm_encoder_hidden_states = self.norm_context(encoder_hidden_states)
+
+        emb = self.linear(temb)
+        (
+            shift_msa,
+            c_shift_msa,
+            scale_msa,
+            c_scale_msa,
+            gate_msa,
+            c_gate_msa,
+            shift_mlp,
+            c_shift_mlp,
+            scale_mlp,
+            c_scale_mlp,
+            gate_mlp,
+            c_gate_mlp,
+        ) = emb.chunk(12, dim=1)
+
+        hidden_states = norm_hidden_states * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
+        encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1)
+
+        return (
+            hidden_states,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            encoder_hidden_states,
+            c_gate_msa,
+            c_shift_mlp,
+            c_scale_mlp,
+            c_gate_mlp,
+        )
+
+
+class CogView4AttnProcessor:
+    """
+    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+    query and key vectors, but does not include spatial normalization.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
+        batch_size, image_seq_length, embed_dim = hidden_states.shape
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        # 1. QKV projections
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+        # 2. QK normalization
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # 3. Rotational positional embeddings applied to latent stream
+        if image_rotary_emb is not None:
+            from ..embeddings import apply_rotary_emb
+
+            query[:, :, text_seq_length:, :] = apply_rotary_emb(
+                query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
+            )
+            key[:, :, text_seq_length:, :] = apply_rotary_emb(
+                key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
+            )
+
+        # 4. Attention
+        if attention_mask is not None:
+            text_attention_mask = attention_mask.float().to(query.device)
+            actual_text_seq_length = text_attention_mask.size(1)
+            new_attention_mask = torch.zeros((batch_size, text_seq_length + image_seq_length), device=query.device)
+            new_attention_mask[:, :actual_text_seq_length] = text_attention_mask
+            new_attention_mask = new_attention_mask.unsqueeze(2)
+            attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2)
+            attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype)
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.type_as(query)
+
+        # 5. Output projection
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+
+        encoder_hidden_states, hidden_states = hidden_states.split(
+            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+        )
+        return hidden_states, encoder_hidden_states
+
+
+class CogView4TransformerBlock(nn.Module):
+    def __init__(
+        self, dim: int = 2560, num_attention_heads: int = 64, attention_head_dim: int = 40, time_embed_dim: int = 512
+    ) -> None:
+        super().__init__()
+
+        # 1. Attention
+        self.norm1 = CogView4AdaLayerNormZero(time_embed_dim, dim)
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            out_dim=dim,
+            bias=True,
+            qk_norm="layer_norm",
+            elementwise_affine=False,
+            eps=1e-5,
+            processor=CogView4AttnProcessor(),
+        )
+
+        # 2. Feedforward
+        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        # 1. Timestep conditioning
+        (
+            norm_hidden_states,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            norm_encoder_hidden_states,
+            c_gate_msa,
+            c_shift_mlp,
+            c_scale_mlp,
+            c_gate_mlp,
+        ) = self.norm1(hidden_states, encoder_hidden_states, temb)
+
+        # 2. Attention
+        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            image_rotary_emb=image_rotary_emb,
+            attention_mask=attention_mask,
+            **kwargs,
+        )
+        hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
+        encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
+
+        # 3. Feedforward
+        norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
+        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) * (
+            1 + c_scale_mlp.unsqueeze(1)
+        ) + c_shift_mlp.unsqueeze(1)
+
+        ff_output = self.ff(norm_hidden_states)
+        ff_output_context = self.ff(norm_encoder_hidden_states)
+        hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1)
+        encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)
+
+        return hidden_states, encoder_hidden_states
+
+
+class CogView4RotaryPosEmbed(nn.Module):
+    def __init__(self, dim: int, patch_size: int, rope_axes_dim: Tuple[int, int], theta: float = 10000.0) -> None:
+        super().__init__()
+
+        self.dim = dim
+        self.patch_size = patch_size
+        self.rope_axes_dim = rope_axes_dim
+        self.theta = theta
+
+    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size, num_channels, height, width = hidden_states.shape
+        height, width = height // self.patch_size, width // self.patch_size
+
+        dim_h, dim_w = self.dim // 2, self.dim // 2
+        h_inv_freq = 1.0 / (
+            self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
+        )
+        w_inv_freq = 1.0 / (
+            self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
+        )
+        h_seq = torch.arange(self.rope_axes_dim[0])
+        w_seq = torch.arange(self.rope_axes_dim[1])
+        freqs_h = torch.outer(h_seq, h_inv_freq)
+        freqs_w = torch.outer(w_seq, w_inv_freq)
+
+        h_idx = torch.arange(height, device=freqs_h.device)
+        w_idx = torch.arange(width, device=freqs_w.device)
+        inner_h_idx = h_idx * self.rope_axes_dim[0] // height
+        inner_w_idx = w_idx * self.rope_axes_dim[1] // width
+
+        freqs_h = freqs_h[inner_h_idx]
+        freqs_w = freqs_w[inner_w_idx]
+
+        # Create position matrices for height and width
+        # [height, 1, dim//4] and [1, width, dim//4]
+        freqs_h = freqs_h.unsqueeze(1)
+        freqs_w = freqs_w.unsqueeze(0)
+        # Broadcast freqs_h and freqs_w to [height, width, dim//4]
+        freqs_h = freqs_h.expand(height, width, -1)
+        freqs_w = freqs_w.expand(height, width, -1)
+
+        # Concatenate along last dimension to get [height, width, dim//2]
+        freqs = torch.cat([freqs_h, freqs_w], dim=-1)
+        freqs = torch.cat([freqs, freqs], dim=-1)  # [height, width, dim]
+        freqs = freqs.reshape(height * width, -1)
+        return (freqs.cos(), freqs.sin())
+
+
+class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
+    r"""
+    Args:
+        patch_size (`int`, defaults to `2`):
+            The size of the patches to use in the patch embedding layer.
+        in_channels (`int`, defaults to `16`):
+            The number of channels in the input.
+        num_layers (`int`, defaults to `30`):
+            The number of layers of Transformer blocks to use.
+        attention_head_dim (`int`, defaults to `40`):
+            The number of channels in each head.
+        num_attention_heads (`int`, defaults to `64`):
+            The number of heads to use for multi-head attention.
+        out_channels (`int`, defaults to `16`):
+            The number of channels in the output.
+        text_embed_dim (`int`, defaults to `4096`):
+            Input dimension of text embeddings from the text encoder.
+        time_embed_dim (`int`, defaults to `512`):
+            Output dimension of timestep embeddings.
+        condition_dim (`int`, defaults to `256`):
+            The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
+            crop_coords).
+        pos_embed_max_size (`int`, defaults to `128`):
+            The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added
+            to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128
+            means that the maximum supported height and width for image generation is `128 * vae_scale_factor *
+            patch_size => 128 * 8 * 2 => 2048`.
+        sample_size (`int`, defaults to `128`):
+            The base resolution of input latents. If height/width is not provided during generation, this value is used
+            to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
+    """
+
+    _supports_gradient_checkpointing = True
+    _no_split_modules = ["CogView4TransformerBlock", "CogView4PatchEmbed", "CogView4PatchEmbed"]
+    _skip_layerwise_casting_patterns = ["patch_embed", "norm", "proj_out"]
+
+    @register_to_config
+    def __init__(
+        self,
+        patch_size: int = 2,
+        in_channels: int = 16,
+        out_channels: int = 16,
+        num_layers: int = 30,
+        attention_head_dim: int = 40,
+        num_attention_heads: int = 64,
+        text_embed_dim: int = 4096,
+        time_embed_dim: int = 512,
+        condition_dim: int = 256,
+        pos_embed_max_size: int = 128,
+        sample_size: int = 128,
+        rope_axes_dim: Tuple[int, int] = (256, 256),
+    ):
+        super().__init__()
+
+        # CogView4 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords
+        # Each of these are sincos embeddings of shape 2 * condition_dim
+        pooled_projection_dim = 3 * 2 * condition_dim
+        inner_dim = num_attention_heads * attention_head_dim
+        out_channels = out_channels
+
+        # 1. RoPE
+        self.rope = CogView4RotaryPosEmbed(attention_head_dim, patch_size, rope_axes_dim, theta=10000.0)
+
+        # 2. Patch & Text-timestep embedding
+        self.patch_embed = CogView4PatchEmbed(in_channels, inner_dim, patch_size, text_embed_dim)
+
+        self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings(
+            embedding_dim=time_embed_dim,
+            condition_dim=condition_dim,
+            pooled_projection_dim=pooled_projection_dim,
+            timesteps_dim=inner_dim,
+        )
+
+        # 3. Transformer blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                CogView4TransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim)
+                for _ in range(num_layers)
+            ]
+        )
+
+        # 4. Output projection
+        self.norm_out = AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
+        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        original_size: torch.Tensor,
+        target_size: torch.Tensor,
+        crop_coords: torch.Tensor,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
+        batch_size, num_channels, height, width = hidden_states.shape
+
+        # 1. RoPE
+        image_rotary_emb = self.rope(hidden_states)
+
+        # 2. Patch & Timestep embeddings
+        p = self.config.patch_size
+        post_patch_height = height // p
+        post_patch_width = width // p
+
+        hidden_states, encoder_hidden_states = self.patch_embed(hidden_states, encoder_hidden_states)
+
+        temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
+        temb = F.silu(temb)
+
+        # 3. Transformer blocks
+        for block in self.transformer_blocks:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
+                )
+            else:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
+                )
+
+        # 4. Output norm & projection
+        hidden_states = self.norm_out(hidden_states, temb)
+        hidden_states = self.proj_out(hidden_states)
+
+        # 5. Unpatchify
+        hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
+        output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (output,)
+        return Transformer2DModelOutput(sample=output)