diffusers 0.32.1__py3-none-any.whl → 0.33.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +198 -28
- diffusers/loaders/lora_conversion_utils.py +679 -44
- diffusers/loaders/lora_pipeline.py +1963 -801
- diffusers/loaders/peft.py +169 -84
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +653 -75
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +22 -32
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +593 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +10 -2
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +14 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
- diffusers-0.33.0.dist-info/RECORD +608 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
- diffusers-0.32.1.dist-info/RECORD +0 -550
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
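The headline additions in 0.33.0 are visible in the list above: a new `diffusers/hooks/` package (group offloading, layerwise casting, pyramid attention broadcast, FasterCache), a generic `AutoModel`, quanto quantizer support, remote-decode utilities, and several new model/pipeline families (Wan, Lumina2, CogView4, OmniGen, EasyAnimate, ConsisID, SANA Sprint). Below is a minimal sketch of how the new memory-optimization hooks might be exercised; the entry points (`apply_group_offloading`, `enable_layerwise_casting`) and their signatures are inferred from the module names above, so verify them against the 0.33.0 documentation before relying on them. The model id is only an illustration.

# Sketch: trying the 0.33.0 memory-optimization hooks on a pipeline's submodules.
# `apply_group_offloading` and `enable_layerwise_casting` are assumptions inferred
# from the new hooks/ modules in this diff -- check the released docs.
import torch
from diffusers import DiffusionPipeline
from diffusers.hooks import apply_group_offloading

pipe = DiffusionPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)

# Layerwise casting: store weights in fp8, upcast per layer for compute.
pipe.transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)

# Group offloading: keep groups of blocks on CPU and onload them as they run.
apply_group_offloading(
    pipe.text_encoder,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
)

Both hooks target individual `ModelMixin` modules rather than whole pipelines, which matches the hook-based design of the new `diffusers/hooks/` package.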
diffusers/models/transformers/transformer_lumina2.py
@@ -0,0 +1,548 @@
+# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...loaders.single_file_model import FromOriginalModelMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import LuminaFeedForward
+from ..attention_processor import Attention
+from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int = 4096,
+        cap_feat_dim: int = 2048,
+        frequency_embedding_size: int = 256,
+        norm_eps: float = 1e-5,
+    ) -> None:
+        super().__init__()
+
+        self.time_proj = Timesteps(
+            num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
+        )
+
+        self.timestep_embedder = TimestepEmbedding(
+            in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
+        )
+
+        self.caption_embedder = nn.Sequential(
+            RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, hidden_size, bias=True)
+        )
+
+    def forward(
+        self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        timestep_proj = self.time_proj(timestep).type_as(hidden_states)
+        time_embed = self.timestep_embedder(timestep_proj)
+        caption_embed = self.caption_embedder(encoder_hidden_states)
+        return time_embed, caption_embed
+
+
+class Lumina2AttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the Lumina2Transformer2DModel model. It applies normalization and RoPE on query and key vectors.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        base_sequence_length: Optional[int] = None,
+    ) -> torch.Tensor:
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        # Get Query-Key-Value Pair
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query_dim = query.shape[-1]
+        inner_dim = key.shape[-1]
+        head_dim = query_dim // attn.heads
+        dtype = query.dtype
+
+        # Get key-value heads
+        kv_heads = inner_dim // head_dim
+
+        query = query.view(batch_size, -1, attn.heads, head_dim)
+        key = key.view(batch_size, -1, kv_heads, head_dim)
+        value = value.view(batch_size, -1, kv_heads, head_dim)
+
+        # Apply Query-Key Norm if needed
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
+            key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
+
+        query, key = query.to(dtype), key.to(dtype)
+
+        # Apply proportional attention if true
+        if base_sequence_length is not None:
+            softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
+        else:
+            softmax_scale = attn.scale
+
+        # perform Grouped-query Attention (GQA)
+        n_rep = attn.heads // kv_heads
+        if n_rep >= 1:
+            key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+            value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+
+        # scaled_dot_product_attention expects attention_mask shape to be
+        # (batch, heads, source_length, target_length)
+        if attention_mask is not None:
+            attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
+
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, scale=softmax_scale
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.type_as(query)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
+
+
+class Lumina2TransformerBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        num_kv_heads: int,
+        multiple_of: int,
+        ffn_dim_multiplier: float,
+        norm_eps: float,
+        modulation: bool = True,
+    ) -> None:
+        super().__init__()
+        self.head_dim = dim // num_attention_heads
+        self.modulation = modulation
+
+        self.attn = Attention(
+            query_dim=dim,
+            cross_attention_dim=None,
+            dim_head=dim // num_attention_heads,
+            qk_norm="rms_norm",
+            heads=num_attention_heads,
+            kv_heads=num_kv_heads,
+            eps=1e-5,
+            bias=False,
+            out_bias=False,
+            processor=Lumina2AttnProcessor2_0(),
+        )
+
+        self.feed_forward = LuminaFeedForward(
+            dim=dim,
+            inner_dim=4 * dim,
+            multiple_of=multiple_of,
+            ffn_dim_multiplier=ffn_dim_multiplier,
+        )
+
+        if modulation:
+            self.norm1 = LuminaRMSNormZero(
+                embedding_dim=dim,
+                norm_eps=norm_eps,
+                norm_elementwise_affine=True,
+            )
+        else:
+            self.norm1 = RMSNorm(dim, eps=norm_eps)
+        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
+
+        self.norm2 = RMSNorm(dim, eps=norm_eps)
+        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        image_rotary_emb: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if self.modulation:
+            norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
+            attn_output = self.attn(
+                hidden_states=norm_hidden_states,
+                encoder_hidden_states=norm_hidden_states,
+                attention_mask=attention_mask,
+                image_rotary_emb=image_rotary_emb,
+            )
+            hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
+            mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
+            hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
+        else:
+            norm_hidden_states = self.norm1(hidden_states)
+            attn_output = self.attn(
+                hidden_states=norm_hidden_states,
+                encoder_hidden_states=norm_hidden_states,
+                attention_mask=attention_mask,
+                image_rotary_emb=image_rotary_emb,
+            )
+            hidden_states = hidden_states + self.norm2(attn_output)
+            mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
+            hidden_states = hidden_states + self.ffn_norm2(mlp_output)
+
+        return hidden_states
+
+
+class Lumina2RotaryPosEmbed(nn.Module):
+    def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, 512, 512), patch_size: int = 2):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+        self.axes_lens = axes_lens
+        self.patch_size = patch_size
+
+        self.freqs_cis = self._precompute_freqs_cis(axes_dim, axes_lens, theta)
+
+    def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
+        freqs_cis = []
+        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+        for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
+            emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=freqs_dtype)
+            freqs_cis.append(emb)
+        return freqs_cis
+
+    def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
+        device = ids.device
+        if ids.device.type == "mps":
+            ids = ids.to("cpu")
+
+        result = []
+        for i in range(len(self.axes_dim)):
+            freqs = self.freqs_cis[i].to(ids.device)
+            index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
+            result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
+        return torch.cat(result, dim=-1).to(device)
+
+    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
+        batch_size, channels, height, width = hidden_states.shape
+        p = self.patch_size
+        post_patch_height, post_patch_width = height // p, width // p
+        image_seq_len = post_patch_height * post_patch_width
+        device = hidden_states.device
+
+        encoder_seq_len = attention_mask.shape[1]
+        l_effective_cap_len = attention_mask.sum(dim=1).tolist()
+        seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len]
+        max_seq_len = max(seq_lengths)
+
+        # Create position IDs
+        position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
+
+        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
+            # add caption position ids
+            position_ids[i, :cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32, device=device)
+            position_ids[i, cap_seq_len:seq_len, 0] = cap_seq_len
+
+            # add image position ids
+            row_ids = (
+                torch.arange(post_patch_height, dtype=torch.int32, device=device)
+                .view(-1, 1)
+                .repeat(1, post_patch_width)
+                .flatten()
+            )
+            col_ids = (
+                torch.arange(post_patch_width, dtype=torch.int32, device=device)
+                .view(1, -1)
+                .repeat(post_patch_height, 1)
+                .flatten()
+            )
+            position_ids[i, cap_seq_len:seq_len, 1] = row_ids
+            position_ids[i, cap_seq_len:seq_len, 2] = col_ids
+
+        # Get combined rotary embeddings
+        freqs_cis = self._get_freqs_cis(position_ids)
+
+        # create separate rotary embeddings for captions and images
+        cap_freqs_cis = torch.zeros(
+            batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
+        )
+        img_freqs_cis = torch.zeros(
+            batch_size, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
+        )
+
+        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
+            cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
+            img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_seq_len:seq_len]
+
+        # image patch embeddings
+        hidden_states = (
+            hidden_states.view(batch_size, channels, post_patch_height, p, post_patch_width, p)
+            .permute(0, 2, 4, 3, 5, 1)
+            .flatten(3)
+            .flatten(1, 2)
+        )
+
+        return hidden_states, cap_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
+
+
+class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+    r"""
+    Lumina2NextDiT: Diffusion model with a Transformer backbone.
+
+    Parameters:
+        sample_size (`int`): The width of the latent images. This is fixed during training since
+            it is used to learn a number of position embeddings.
+        patch_size (`int`, *optional*, defaults to 2):
+            The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
+        in_channels (`int`, *optional*, defaults to 4):
+            The number of input channels for the model. Typically, this matches the number of channels in the input
+            images.
+        hidden_size (`int`, *optional*, defaults to 4096):
+            The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
+            hidden representations.
+        num_layers (`int`, *optional*, defaults to 32):
+            The number of layers in the model. This defines the depth of the neural network.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            The number of attention heads in each attention layer. This parameter specifies how many separate attention
+            mechanisms are used.
+        num_kv_heads (`int`, *optional*, defaults to 8):
+            The number of key-value heads in the attention mechanism, if different from the number of attention heads.
+            If None, it defaults to num_attention_heads.
+        multiple_of (`int`, *optional*, defaults to 256):
+            A factor that the hidden size should be a multiple of. This can help optimize certain hardware
+            configurations.
+        ffn_dim_multiplier (`float`, *optional*):
+            A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
+            the model configuration.
+        norm_eps (`float`, *optional*, defaults to 1e-5):
+            A small value added to the denominator for numerical stability in normalization layers.
+        scaling_factor (`float`, *optional*, defaults to 1.0):
+            A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
+            overall scale of the model's operations.
+    """
+
+    _supports_gradient_checkpointing = True
+    _no_split_modules = ["Lumina2TransformerBlock"]
+    _skip_layerwise_casting_patterns = ["x_embedder", "norm"]
+
+    @register_to_config
+    def __init__(
+        self,
+        sample_size: int = 128,
+        patch_size: int = 2,
+        in_channels: int = 16,
+        out_channels: Optional[int] = None,
+        hidden_size: int = 2304,
+        num_layers: int = 26,
+        num_refiner_layers: int = 2,
+        num_attention_heads: int = 24,
+        num_kv_heads: int = 8,
+        multiple_of: int = 256,
+        ffn_dim_multiplier: Optional[float] = None,
+        norm_eps: float = 1e-5,
+        scaling_factor: float = 1.0,
+        axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
+        axes_lens: Tuple[int, int, int] = (300, 512, 512),
+        cap_feat_dim: int = 1024,
+    ) -> None:
+        super().__init__()
+        self.out_channels = out_channels or in_channels
+
+        # 1. Positional, patch & conditional embeddings
+        self.rope_embedder = Lumina2RotaryPosEmbed(
+            theta=10000, axes_dim=axes_dim_rope, axes_lens=axes_lens, patch_size=patch_size
+        )
+
+        self.x_embedder = nn.Linear(in_features=patch_size * patch_size * in_channels, out_features=hidden_size)
+
+        self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
+            hidden_size=hidden_size, cap_feat_dim=cap_feat_dim, norm_eps=norm_eps
+        )
+
+        # 2. Noise and context refinement blocks
+        self.noise_refiner = nn.ModuleList(
+            [
+                Lumina2TransformerBlock(
+                    hidden_size,
+                    num_attention_heads,
+                    num_kv_heads,
+                    multiple_of,
+                    ffn_dim_multiplier,
+                    norm_eps,
+                    modulation=True,
+                )
+                for _ in range(num_refiner_layers)
+            ]
+        )
+
+        self.context_refiner = nn.ModuleList(
+            [
+                Lumina2TransformerBlock(
+                    hidden_size,
+                    num_attention_heads,
+                    num_kv_heads,
+                    multiple_of,
+                    ffn_dim_multiplier,
+                    norm_eps,
+                    modulation=False,
+                )
+                for _ in range(num_refiner_layers)
+            ]
+        )
+
+        # 3. Transformer blocks
+        self.layers = nn.ModuleList(
+            [
+                Lumina2TransformerBlock(
+                    hidden_size,
+                    num_attention_heads,
+                    num_kv_heads,
+                    multiple_of,
+                    ffn_dim_multiplier,
+                    norm_eps,
+                    modulation=True,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        # 4. Output norm & projection
+        self.norm_out = LuminaLayerNormContinuous(
+            embedding_dim=hidden_size,
+            conditioning_embedding_dim=min(hidden_size, 1024),
+            elementwise_affine=False,
+            eps=1e-6,
+            bias=True,
+            out_dim=patch_size * patch_size * self.out_channels,
+        )
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        timestep: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_attention_mask: torch.Tensor,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
+        # 1. Condition, positional & patch embedding
+        batch_size, _, height, width = hidden_states.shape
+
+        temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
+
+        (
+            hidden_states,
+            context_rotary_emb,
+            noise_rotary_emb,
+            rotary_emb,
+            encoder_seq_lengths,
+            seq_lengths,
+        ) = self.rope_embedder(hidden_states, encoder_attention_mask)
+
+        hidden_states = self.x_embedder(hidden_states)
+
+        # 2. Context & noise refinement
+        for layer in self.context_refiner:
+            encoder_hidden_states = layer(encoder_hidden_states, encoder_attention_mask, context_rotary_emb)
+
+        for layer in self.noise_refiner:
+            hidden_states = layer(hidden_states, None, noise_rotary_emb, temb)
+
+        # 3. Joint Transformer blocks
+        max_seq_len = max(seq_lengths)
+        use_mask = len(set(seq_lengths)) > 1
+
+        attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
+        joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
+        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
+            attention_mask[i, :seq_len] = True
+            joint_hidden_states[i, :encoder_seq_len] = encoder_hidden_states[i, :encoder_seq_len]
+            joint_hidden_states[i, encoder_seq_len:seq_len] = hidden_states[i]
+
+        hidden_states = joint_hidden_states
+
+        for layer in self.layers:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states = self._gradient_checkpointing_func(
+                    layer, hidden_states, attention_mask if use_mask else None, rotary_emb, temb
+                )
+            else:
+                hidden_states = layer(hidden_states, attention_mask if use_mask else None, rotary_emb, temb)
+
+        # 4. Output norm & projection
+        hidden_states = self.norm_out(hidden_states, temb)
+
+        # 5. Unpatchify
+        p = self.config.patch_size
+        output = []
+        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
+            output.append(
+                hidden_states[i][encoder_seq_len:seq_len]
+                .view(height // p, width // p, p, p, self.out_channels)
+                .permute(4, 0, 2, 1, 3)
+                .flatten(3, 4)
+                .flatten(1, 2)
+            )
+        output = torch.stack(output, dim=0)
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (output,)
+        return Transformer2DModelOutput(sample=output)
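Two details in `Lumina2AttnProcessor2_0` above deserve a closer look: key/value heads are repeated `n_rep` times so grouped-query attention can be fed to plain `F.scaled_dot_product_attention`, and when `base_sequence_length` is given the softmax scale grows as sqrt(log_b(L)) of the current sequence length L ("proportional attention", which keeps attention entropy roughly stable at resolutions above the training one). A self-contained toy reproduction of both tricks; the shapes and names here are illustrative only, not part of the diffusers API (requires PyTorch >= 2.1 for the `scale` kwarg):

# Toy reproduction of the GQA repeat and proportional-attention scale used above.
import math
import torch
import torch.nn.functional as F

batch, seq_len, heads, kv_heads, head_dim = 2, 16, 8, 2, 4
base_seq_len = 8  # pretend training-time sequence length

query = torch.randn(batch, seq_len, heads, head_dim)
key = torch.randn(batch, seq_len, kv_heads, head_dim)
value = torch.randn(batch, seq_len, kv_heads, head_dim)

# GQA: each group of query heads shares one kv head; repeat kv heads to match `heads`.
n_rep = heads // kv_heads  # 4
key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)    # -> (2, 16, 8, 4)
value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

# Proportional attention: sharpen the softmax as the sequence grows past base_seq_len.
# attn.scale in diffusers defaults to 1/sqrt(head_dim), mirrored here explicitly.
scale = math.sqrt(math.log(seq_len, base_seq_len)) / math.sqrt(head_dim)

out = F.scaled_dot_product_attention(
    query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), scale=scale
)
print(out.shape)  # torch.Size([2, 8, 16, 4])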
diffusers/models/transformers/transformer_mochi.py
@@ -21,10 +21,11 @@ import torch.nn as nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
 from ...loaders.single_file_model import FromOriginalModelMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
+from ..cache_utils import CacheMixin
 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -305,7 +306,7 @@ class MochiRoPE(nn.Module):
 
 
 @maybe_allow_in_graph
-class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
     r"""
     A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
 
@@ -336,6 +337,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
 
     _supports_gradient_checkpointing = True
     _no_split_modules = ["MochiTransformerBlock"]
+    _skip_layerwise_casting_patterns = ["patch_embed", "norm"]
 
     @register_to_config
     def __init__(
@@ -402,10 +404,6 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
 
         self.gradient_checkpointing = False
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -458,22 +456,13 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
 
         for i, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     encoder_hidden_states,
                     temb,
                     encoder_attention_mask,
                     image_rotary_emb,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states, encoder_hidden_states = block(