diffusers 0.33.1__py3-none-any.whl → 0.34.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +48 -1
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/diffusers_cli.py +1 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/hooks/faster_cache.py +2 -2
- diffusers/hooks/group_offloading.py +128 -29
- diffusers/hooks/hooks.py +2 -2
- diffusers/hooks/layerwise_casting.py +3 -3
- diffusers/hooks/pyramid_attention_broadcast.py +1 -1
- diffusers/image_processor.py +7 -2
- diffusers/loaders/__init__.py +4 -0
- diffusers/loaders/ip_adapter.py +5 -14
- diffusers/loaders/lora_base.py +212 -111
- diffusers/loaders/lora_conversion_utils.py +275 -34
- diffusers/loaders/lora_pipeline.py +1554 -819
- diffusers/loaders/peft.py +52 -109
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_model.py +20 -4
- diffusers/loaders/single_file_utils.py +225 -5
- diffusers/loaders/textual_inversion.py +3 -2
- diffusers/loaders/transformer_flux.py +1 -1
- diffusers/loaders/transformer_sd3.py +2 -2
- diffusers/loaders/unet.py +2 -16
- diffusers/loaders/unet_loader_utils.py +1 -1
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +15 -1
- diffusers/models/activations.py +5 -5
- diffusers/models/adapter.py +2 -3
- diffusers/models/attention.py +4 -4
- diffusers/models/attention_flax.py +10 -10
- diffusers/models/attention_processor.py +14 -10
- diffusers/models/auto_model.py +47 -10
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_dc.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1108 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_wan.py +256 -22
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +13 -2
- diffusers/models/autoencoders/vq_model.py +2 -2
- diffusers/models/cache_utils.py +1 -1
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flux.py +1 -1
- diffusers/models/controlnet_sd3.py +1 -1
- diffusers/models/controlnet_sparsectrl.py +1 -1
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -3
- diffusers/models/controlnets/controlnet_flax.py +1 -1
- diffusers/models/controlnets/controlnet_flux.py +16 -15
- diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
- diffusers/models/controlnets/controlnet_sana.py +290 -0
- diffusers/models/controlnets/controlnet_sd3.py +1 -1
- diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
- diffusers/models/controlnets/controlnet_union.py +1 -1
- diffusers/models/controlnets/controlnet_xs.py +7 -7
- diffusers/models/controlnets/multicontrolnet.py +4 -5
- diffusers/models/controlnets/multicontrolnet_union.py +5 -6
- diffusers/models/downsampling.py +2 -2
- diffusers/models/embeddings.py +10 -12
- diffusers/models/embeddings_flax.py +2 -2
- diffusers/models/lora.py +3 -3
- diffusers/models/modeling_utils.py +44 -14
- diffusers/models/normalization.py +4 -4
- diffusers/models/resnet.py +2 -2
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
- diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
- diffusers/models/transformers/consisid_transformer_3d.py +1 -1
- diffusers/models/transformers/dit_transformer_2d.py +2 -2
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
- diffusers/models/transformers/latte_transformer_3d.py +4 -5
- diffusers/models/transformers/lumina_nextdit2d.py +2 -2
- diffusers/models/transformers/pixart_transformer_2d.py +3 -3
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/sana_transformer.py +8 -3
- diffusers/models/transformers/stable_audio_transformer.py +5 -9
- diffusers/models/transformers/t5_film_transformer.py +3 -3
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +1 -1
- diffusers/models/transformers/transformer_chroma.py +742 -0
- diffusers/models/transformers/transformer_cogview3plus.py +5 -10
- diffusers/models/transformers/transformer_cogview4.py +317 -25
- diffusers/models/transformers/transformer_cosmos.py +579 -0
- diffusers/models/transformers/transformer_flux.py +9 -11
- diffusers/models/transformers/transformer_hidream_image.py +942 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -8
- diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- diffusers/models/transformers/transformer_ltx.py +2 -2
- diffusers/models/transformers/transformer_lumina2.py +1 -1
- diffusers/models/transformers/transformer_mochi.py +1 -1
- diffusers/models/transformers/transformer_omnigen.py +2 -2
- diffusers/models/transformers/transformer_sd3.py +7 -7
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/transformers/transformer_wan.py +24 -8
- diffusers/models/transformers/transformer_wan_vace.py +393 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +1 -1
- diffusers/models/unets/unet_2d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
- diffusers/models/unets/unet_2d_condition.py +2 -2
- diffusers/models/unets/unet_2d_condition_flax.py +2 -2
- diffusers/models/unets/unet_3d_blocks.py +1 -1
- diffusers/models/unets/unet_3d_condition.py +3 -3
- diffusers/models/unets/unet_i2vgen_xl.py +3 -3
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +2 -2
- diffusers/models/unets/unet_stable_cascade.py +1 -1
- diffusers/models/upsampling.py +2 -2
- diffusers/models/vae_flax.py +2 -2
- diffusers/models/vq_model.py +1 -1
- diffusers/pipelines/__init__.py +37 -6
- diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
- diffusers/pipelines/amused/pipeline_amused.py +7 -6
- diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
- diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
- diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +23 -13
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
- diffusers/pipelines/auto_pipeline.py +6 -7
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
- diffusers/pipelines/chroma/__init__.py +49 -0
- diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- diffusers/pipelines/chroma/pipeline_output.py +21 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +8 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +8 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +8 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +8 -8
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
- diffusers/pipelines/cogview4/pipeline_cogview4.py +7 -7
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
- diffusers/pipelines/consisid/consisid_utils.py +2 -2
- diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +8 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +5 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
- diffusers/pipelines/cosmos/__init__.py +54 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
- diffusers/pipelines/cosmos/pipeline_output.py +40 -0
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
- diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +1 -1
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
- diffusers/pipelines/flux/modeling_flux.py +1 -1
- diffusers/pipelines/flux/pipeline_flux.py +10 -17
- diffusers/pipelines/flux/pipeline_flux_control.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +30 -22
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +2 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_img2img.py +39 -6
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +11 -6
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/free_init_utils.py +2 -2
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hidream_image/__init__.py +47 -0
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
- diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
- diffusers/pipelines/hunyuan_video/__init__.py +2 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
- diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
- diffusers/pipelines/kolors/text_encoder.py +3 -3
- diffusers/pipelines/kolors/tokenizer.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
- diffusers/pipelines/latte/pipeline_latte.py +12 -12
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
- diffusers/pipelines/ltx/__init__.py +4 -0
- diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +51 -6
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +107 -29
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +50 -6
- diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
- diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
- diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
- diffusers/pipelines/mochi/pipeline_mochi.py +6 -6
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
- diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
- diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
- diffusers/pipelines/onnx_utils.py +15 -2
- diffusers/pipelines/pag/pag_utils.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
- diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
- diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
- diffusers/pipelines/pia/pipeline_pia.py +8 -6
- diffusers/pipelines/pipeline_flax_utils.py +3 -4
- diffusers/pipelines/pipeline_loading_utils.py +89 -13
- diffusers/pipelines/pipeline_utils.py +105 -33
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +11 -11
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +11 -11
- diffusers/pipelines/sana/__init__.py +4 -0
- diffusers/pipelines/sana/pipeline_sana.py +23 -21
- diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
- diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +3 -3
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
- diffusers/pipelines/stable_diffusion/__init__.py +0 -7
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +9 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +8 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
- diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
- diffusers/pipelines/unclip/text_proj.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
- diffusers/pipelines/visualcloze/__init__.py +52 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
- diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
- diffusers/pipelines/wan/__init__.py +2 -0
- diffusers/pipelines/wan/pipeline_wan.py +13 -10
- diffusers/pipelines/wan/pipeline_wan_i2v.py +38 -18
- diffusers/pipelines/wan/pipeline_wan_vace.py +976 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
- diffusers/quantizers/__init__.py +179 -1
- diffusers/quantizers/base.py +6 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
- diffusers/quantizers/bitsandbytes/utils.py +10 -7
- diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
- diffusers/quantizers/gguf/utils.py +16 -13
- diffusers/quantizers/quantization_config.py +18 -16
- diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +5 -1
- diffusers/schedulers/__init__.py +3 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -1
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
- diffusers/schedulers/scheduling_ddim.py +8 -8
- diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_ddim_flax.py +6 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
- diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
- diffusers/schedulers/scheduling_ddpm.py +9 -9
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
- diffusers/schedulers/scheduling_deis_multistep.py +8 -8
- diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -12
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -13
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
- diffusers/schedulers/scheduling_edm_euler.py +20 -11
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
- diffusers/schedulers/scheduling_heun_discrete.py +2 -2
- diffusers/schedulers/scheduling_ipndm.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
- diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
- diffusers/schedulers/scheduling_lcm.py +3 -3
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +4 -4
- diffusers/schedulers/scheduling_pndm_flax.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +9 -9
- diffusers/schedulers/scheduling_sasolver.py +15 -15
- diffusers/schedulers/scheduling_scm.py +1 -1
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
- diffusers/schedulers/scheduling_tcd.py +3 -3
- diffusers/schedulers/scheduling_unclip.py +5 -5
- diffusers/schedulers/scheduling_unipc_multistep.py +11 -11
- diffusers/schedulers/scheduling_utils.py +1 -1
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +13 -5
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +120 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +225 -0
- diffusers/utils/dynamic_modules_utils.py +21 -3
- diffusers/utils/export_utils.py +1 -1
- diffusers/utils/import_utils.py +81 -18
- diffusers/utils/logging.py +1 -1
- diffusers/utils/outputs.py +2 -1
- diffusers/utils/peft_utils.py +91 -8
- diffusers/utils/state_dict_utils.py +20 -3
- diffusers/utils/testing_utils.py +59 -7
- diffusers/utils/torch_utils.py +25 -5
- diffusers/video_processor.py +2 -2
- {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/METADATA +70 -55
- diffusers-0.34.0.dist-info/RECORD +639 -0
- {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/WHEEL +1 -1
- diffusers-0.33.1.dist-info/RECORD +0 -608
- {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/LICENSE +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/top_level.txt +0 -0

diffusers/models/modeling_utils.py CHANGED
@@ -548,6 +548,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         use_stream: bool = False,
         record_stream: bool = False,
         low_cpu_mem_usage=False,
+        offload_to_disk_path: Optional[str] = None,
     ) -> None:
         r"""
         Activates group offloading for the current model.
@@ -588,15 +589,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 f"open an issue at https://github.com/huggingface/diffusers/issues."
             )
         apply_group_offloading(
-            self,
-            onload_device,
-            offload_device,
-            offload_type,
-            num_blocks_per_group,
-            non_blocking,
-            use_stream,
-            record_stream,
+            module=self,
+            onload_device=onload_device,
+            offload_device=offload_device,
+            offload_type=offload_type,
+            num_blocks_per_group=num_blocks_per_group,
+            non_blocking=non_blocking,
+            use_stream=use_stream,
+            record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            offload_to_disk_path=offload_to_disk_path,
         )

     def save_pretrained(
@@ -787,9 +789,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                 is not used.
-            torch_dtype (`
-                Override the default `torch.dtype` and load the model with another dtype.
-                dtype is automatically derived from the model's weights.
+            torch_dtype (`torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model with another dtype.
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
@@ -815,14 +816,43 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
                 guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
                 information.
-            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+            device_map (`Union[int, str, torch.device]` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
                 A map that specifies where each submodule should go. It doesn't need to be defined for each
                 parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
                 same device. Defaults to `None`, meaning that the model will be loaded on CPU.

+                Examples:
+
+                ```py
+                >>> from diffusers import AutoModel
+                >>> import torch
+
+                >>> # This works.
+                >>> model = AutoModel.from_pretrained(
+                ...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="cuda"
+                ... )
+                >>> # This also works (integer accelerator device ID).
+                >>> model = AutoModel.from_pretrained(
+                ...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map=0
+                ... )
+                >>> # Specifying a supported offloading strategy like "auto" also works.
+                >>> model = AutoModel.from_pretrained(
+                ...     "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="auto"
+                ... )
+                >>> # Specifying a dictionary as `device_map` also works.
+                >>> model = AutoModel.from_pretrained(
+                ...     "stabilityai/stable-diffusion-xl-base-1.0",
+                ...     subfolder="unet",
+                ...     device_map={"": torch.device("cuda")},
+                ... )
+                ```
+
                 Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
                 more information about each option see [designing a device
-                map](https://
+                map](https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap). You
+                can also refer to the [Diffusers-specific
+                documentation](https://huggingface.co/docs/diffusers/main/en/training/distributed_inference#model-sharding)
+                for more concrete examples.
             max_memory (`Dict`, *optional*):
                 A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
                 each GPU and the available CPU RAM if unset.
@@ -1388,7 +1418,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         low_cpu_mem_usage: bool = True,
         dtype: Optional[Union[str, torch.dtype]] = None,
         keep_in_fp32_modules: Optional[List[str]] = None,
-        device_map: Dict[str, Union[int, str, torch.device]] = None,
+        device_map: Union[str, int, torch.device, Dict[str, Union[int, str, torch.device]]] = None,
         offload_state_dict: Optional[bool] = None,
         offload_folder: Optional[Union[str, os.PathLike]] = None,
         dduf_entries: Optional[Dict[str, DDUFEntry]] = None,

diffusers/models/normalization.py CHANGED
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright
+# Copyright 2025 HuggingFace Inc.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -237,7 +237,7 @@ class AdaLayerNormSingle(nn.Module):
     r"""
     Norm layer adaptive layer norm single (adaLN-single).

-    As proposed in PixArt-Alpha (see: https://
+    As proposed in PixArt-Alpha (see: https://huggingface.co/papers/2310.00426; Section 2.3).

     Parameters:
         embedding_dim (`int`): The size of each embedding vector.
@@ -510,7 +510,7 @@ else:

 class RMSNorm(nn.Module):
     r"""
-    RMS Norm as introduced in https://
+    RMS Norm as introduced in https://huggingface.co/papers/1910.07467 by Zhang et al.

     Args:
         dim (`int`): Number of dimensions to use for `weights`. Only effective when `elementwise_affine` is True.
@@ -600,7 +600,7 @@ class MochiRMSNorm(nn.Module):

 class GlobalResponseNorm(nn.Module):
     r"""
-    Global response normalization as introduced in ConvNeXt-v2 (https://
+    Global response normalization as introduced in ConvNeXt-v2 (https://huggingface.co/papers/2301.00808).

     Args:
         dim (`int`): Number of dimensions to use for the `gamma` and `beta`.

diffusers/models/resnet.py CHANGED
@@ -1,5 +1,5 @@
-# Copyright
-# `TemporalConvLayer` Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+# `TemporalConvLayer` Copyright 2025 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

diffusers/models/resnet_flax.py CHANGED

diffusers/models/transformers/__init__.py CHANGED
@@ -17,11 +17,15 @@ if is_torch_available():
     from .t5_film_transformer import T5FilmDecoder
     from .transformer_2d import Transformer2DModel
     from .transformer_allegro import AllegroTransformer3DModel
+    from .transformer_chroma import ChromaTransformer2DModel
     from .transformer_cogview3plus import CogView3PlusTransformer2DModel
     from .transformer_cogview4 import CogView4Transformer2DModel
+    from .transformer_cosmos import CosmosTransformer3DModel
     from .transformer_easyanimate import EasyAnimateTransformer3DModel
     from .transformer_flux import FluxTransformer2DModel
+    from .transformer_hidream_image import HiDreamImageTransformer2DModel
     from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
+    from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
     from .transformer_ltx import LTXVideoTransformer3DModel
     from .transformer_lumina2 import Lumina2Transformer2DModel
     from .transformer_mochi import MochiTransformer3DModel
@@ -29,3 +33,4 @@ if is_torch_available():
     from .transformer_sd3 import SD3Transformer2DModel
     from .transformer_temporal import TransformerTemporalModel
     from .transformer_wan import WanTransformer3DModel
+    from .transformer_wan_vace import WanVACETransformer3DModel
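
The five new imports above correspond to the headline model additions of this release. As confirmed by this hunk, they are importable from `diffusers.models.transformers` when torch is installed; top-level re-export from `diffusers` is likely (see the `diffusers/__init__.py` +48 -1 entry in the file list) but not shown here:

```py
# Confirmed by the hunk above: new transformer classes added in 0.34.0.
from diffusers.models.transformers import (
    ChromaTransformer2DModel,
    CosmosTransformer3DModel,
    HiDreamImageTransformer2DModel,
    HunyuanVideoFramepackTransformer3DModel,
    WanVACETransformer3DModel,
)
```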

diffusers/models/transformers/auraflow_transformer_2d.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 AuraFlow Authors, The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,15 +13,15 @@
 # limitations under the License.


-from typing import Dict, Union
+from typing import Any, Dict, Optional, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import FromOriginalModelMixin
-from ...utils import logging
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention_processor import (
     Attention,
@@ -74,15 +74,23 @@ class AuraFlowPatchEmbed(nn.Module):
         # PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected
         # because original input are in flattened format, we have to flatten this 2d grid as well.
         h_p, w_p = h // self.patch_size, w // self.patch_size
-        original_pe_indexes = torch.arange(self.pos_embed.shape[1])
         h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
-
+
+        # Calculate the top-left corner indices for the centered patch grid
         starth = h_max // 2 - h_p // 2
-        endh = starth + h_p
         startw = w_max // 2 - w_p // 2
-
-
-
+
+        # Generate the row and column indices for the desired patch grid
+        rows = torch.arange(starth, starth + h_p, device=self.pos_embed.device)
+        cols = torch.arange(startw, startw + w_p, device=self.pos_embed.device)
+
+        # Create a 2D grid of indices
+        row_indices, col_indices = torch.meshgrid(rows, cols, indexing="ij")
+
+        # Convert the 2D grid indices to flattened 1D indices
+        selected_indices = (row_indices * w_max + col_indices).flatten()
+
+        return selected_indices

     def forward(self, latent):
         batch_size, num_channels, height, width = latent.size()
@@ -160,14 +168,20 @@ class AuraFlowSingleTransformerBlock(nn.Module):
         self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
         self.ff = AuraFlowFeedForward(dim, dim * 4)

-    def forward(
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+    ):
         residual = hidden_states
+        attention_kwargs = attention_kwargs or {}

         # Norm + Projection.
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

         # Attention.
-        attn_output = self.attn(hidden_states=norm_hidden_states)
+        attn_output = self.attn(hidden_states=norm_hidden_states, **attention_kwargs)

         # Process attention outputs for the `hidden_states`.
         hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
@@ -223,10 +237,15 @@ class AuraFlowJointTransformerBlock(nn.Module):
         self.ff_context = AuraFlowFeedForward(dim, dim * 4)

     def forward(
-        self,
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
         residual = hidden_states
         residual_context = encoder_hidden_states
+        attention_kwargs = attention_kwargs or {}

         # Norm + Projection.
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
@@ -236,7 +255,9 @@ class AuraFlowJointTransformerBlock(nn.Module):

         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states,
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **attention_kwargs,
         )

         # Process attention outputs for the `hidden_states`.
@@ -254,7 +275,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
         return encoder_hidden_states, hidden_states


-class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     r"""
     A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).

@@ -262,17 +283,17 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
         sample_size (`int`): The width of the latent images. This is fixed during training since
             it is used to learn a number of position embeddings.
         patch_size (`int`): Patch size to turn the input data into small patches.
-        in_channels (`int`, *optional*, defaults to
+        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input.
         num_mmdit_layers (`int`, *optional*, defaults to 4): The number of layers of MMDiT Transformer blocks to use.
-        num_single_dit_layers (`int`, *optional*, defaults to
+        num_single_dit_layers (`int`, *optional*, defaults to 32):
             The number of layers of Transformer blocks to use. These blocks use concatenated image and text
             representations.
-        attention_head_dim (`int`, *optional*, defaults to
-        num_attention_heads (`int`, *optional*, defaults to
+        attention_head_dim (`int`, *optional*, defaults to 256): The number of channels in each head.
+        num_attention_heads (`int`, *optional*, defaults to 12): The number of heads to use for multi-head attention.
         joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
         caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
-        out_channels (`int`, defaults to
-        pos_embed_max_size (`int`, defaults to
+        out_channels (`int`, defaults to 4): Number of output channels.
+        pos_embed_max_size (`int`, defaults to 1024): Maximum positions to embed from the image latents.
     """

     _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
@@ -338,7 +359,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
         self.norm_out = AuraFlowPreFinalBlock(self.inner_dim, self.inner_dim)
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)

-        # https://
+        # https://huggingface.co/papers/2309.16588
         # prevents artifacts in the attention maps
         self.register_tokens = nn.Parameter(torch.randn(1, 8, self.inner_dim) * 0.02)

@@ -449,8 +470,24 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: torch.FloatTensor = None,
         timestep: torch.LongTensor = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
         height, width = hidden_states.shape[-2:]

         # Apply patch embedding, timestep embedding, and project the caption embeddings.
@@ -474,7 +511,10 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin

         else:
             encoder_hidden_states, hidden_states = block(
-                hidden_states=hidden_states,
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                temb=temb,
+                attention_kwargs=attention_kwargs,
             )

         # Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
@@ -491,7 +531,9 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
             )

         else:
-            combined_hidden_states = block(
+            combined_hidden_states = block(
+                hidden_states=combined_hidden_states, temb=temb, attention_kwargs=attention_kwargs
+            )

         hidden_states = combined_hidden_states[:, encoder_seq_len:]

@@ -512,6 +554,10 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
             shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
         )

+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
         if not return_dict:
             return (output,)
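
The `pe_selection_index_based_on_dim` rewrite above swaps the old arange-and-slice indexing for an explicit `torch.meshgrid` over a centered window of the stored positional-embedding grid. A self-contained demo of the selection rule, with arbitrary illustrative sizes:

```py
import torch

pos_embed_max_size = 64         # stored table covers an 8x8 grid (h_max = w_max = 8)
h_p, w_p = 4, 6                 # patch-grid size needed for the current input

h_max = w_max = int(pos_embed_max_size**0.5)
starth = h_max // 2 - h_p // 2  # top-left corner of the centered window
startw = w_max // 2 - w_p // 2

rows = torch.arange(starth, starth + h_p)
cols = torch.arange(startw, startw + w_p)
row_indices, col_indices = torch.meshgrid(rows, cols, indexing="ij")

# Flattened 1D indices into the stored positional-embedding table.
selected_indices = (row_indices * w_max + col_indices).flatten()
print(selected_indices.shape)   # torch.Size([24])
```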
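
With `PeftAdapterMixin` mixed in and `attention_kwargs` threaded through every block, the transformer pops a `scale` key, applies it to PEFT LoRA layers via `scale_lora_layers`, and undoes it after the forward pass. A hedged sketch of what this enables end to end; the LoRA repository ID is a placeholder, and it assumes `AuraFlowPipeline` (also updated in this release, see `pipeline_aura_flow.py` in the file list) forwards an `attention_kwargs` argument of the same name to this `forward`:

```py
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")
pipe.load_lora_weights("your-username/your-auraflow-lora")  # placeholder LoRA repo

image = pipe(
    "a watercolor fox in a forest",
    attention_kwargs={"scale": 0.8},  # popped as `lora_scale`, applied to PEFT layers
).images[0]
```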

diffusers/models/transformers/cogvideox_transformer_3d.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

diffusers/models/transformers/consisid_transformer_3d.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 ConsisID Authors and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

diffusers/models/transformers/dit_transformer_2d.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -30,7 +30,7 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 class DiTTransformer2DModel(ModelMixin, ConfigMixin):
     r"""
-    A 2D Transformer model as introduced in DiT (https://
+    A 2D Transformer model as introduced in DiT (https://huggingface.co/papers/2212.09748).

     Parameters:
         num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.

diffusers/models/transformers/hunyuan_transformer_2d.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -308,7 +308,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
                     activation_fn=activation_fn,
                     ff_inner_dim=int(self.inner_dim * mlp_ratio),
                     cross_attention_dim=cross_attention_dim,
-                    qk_norm=True,  # See
+                    qk_norm=True,  # See https://huggingface.co/papers/2302.05442 for details.
                     skip=layer > num_layers // 2,
                 )
                 for layer in range(num_layers)

diffusers/models/transformers/latte_transformer_3d.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 the Latte Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,10 +18,9 @@ import torch
 from torch import nn

 from ...configuration_utils import ConfigMixin, register_to_config
-from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
 from ..attention import BasicTransformerBlock
 from ..cache_utils import CacheMixin
-from ..embeddings import PatchEmbed
+from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormSingle
@@ -31,7 +30,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
     _supports_gradient_checkpointing = True

     """
-    A 3D Transformer model for video-like data, paper: https://
+    A 3D Transformer model for video-like data, paper: https://huggingface.co/papers/2401.03048, official code:
     https://github.com/Vchitect/Latte

     Parameters:
@@ -217,7 +216,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
         )
         num_patches = height * width

-        hidden_states = self.pos_embed(hidden_states)  #
+        hidden_states = self.pos_embed(hidden_states)  # already add positional embeddings

         added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
         timestep, embedded_timestep = self.adaln_single(

diffusers/models/transformers/lumina_nextdit2d.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -43,7 +43,7 @@ class LuminaNextDiTBlock(nn.Module):
         num_kv_heads (`int`):
             Number of attention heads in key and value features (if using GQA), or set to None for the same as query.
         multiple_of (`int`): The number of multiple of ffn layer.
-        ffn_dim_multiplier (`float`): The
+        ffn_dim_multiplier (`float`): The multiplier factor of ffn layer dimension.
         norm_eps (`float`): The eps for norm layer.
         qk_norm (`bool`): normalization for query and key.
         cross_attention_dim (`int`): Cross attention embedding dimension of the input text prompt hidden_states.

diffusers/models/transformers/pixart_transformer_2d.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,8 +31,8 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
     r"""
-    A 2D Transformer model as introduced in PixArt family of models (https://
-    https://
+    A 2D Transformer model as introduced in PixArt family of models (https://huggingface.co/papers/2310.00426,
+    https://huggingface.co/papers/2403.04692).

     Parameters:
         num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.

diffusers/models/transformers/prior_transformer.py CHANGED
@@ -61,7 +61,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
         added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
             Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
             product between the text embedding and image embedding as proposed in the unclip paper
-            https://
+            https://huggingface.co/papers/2204.06125 If it is `None`, no additional embeddings will be prepended.
         time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
             If None, will be set to `num_attention_heads * attention_head_dim`
         embedding_proj_dim (`int`, *optional*, default to None):

diffusers/models/transformers/sana_transformer.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -483,6 +483,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         encoder_attention_mask: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
+        controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
         return_dict: bool = True,
     ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
         if attention_kwargs is not None:
@@ -546,7 +547,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig

         # 2. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            for block in self.transformer_blocks:
+            for index_block, block in enumerate(self.transformer_blocks):
                 hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
@@ -557,9 +558,11 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
                     post_patch_height,
                     post_patch_width,
                 )
+                if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
+                    hidden_states = hidden_states + controlnet_block_samples[index_block - 1]

         else:
-            for block in self.transformer_blocks:
+            for index_block, block in enumerate(self.transformer_blocks):
                 hidden_states = block(
                     hidden_states,
                     attention_mask,
@@ -569,6 +572,8 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
                     post_patch_height,
                     post_patch_width,
                 )
+                if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
+                    hidden_states = hidden_states + controlnet_block_samples[index_block - 1]

         # 3. Normalization
         hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
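
The injection rule added above is: after transformer block `i` (counting from zero), add ControlNet residual `i - 1`, for `0 < i <= len(controlnet_block_samples)`; block 0 and any blocks beyond the provided residuals are left untouched. A minimal, self-contained illustration of that rule with toy modules standing in for the Sana blocks:

```py
import torch
import torch.nn as nn

# Toy stand-ins for transformer blocks and ControlNet residuals.
blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
hidden_states = torch.randn(2, 8)
controlnet_block_samples = tuple(torch.randn(2, 8) for _ in range(2))

for index_block, block in enumerate(blocks):
    hidden_states = block(hidden_states)
    # Same condition as the diff: residuals feed blocks 1..len(samples);
    # here a residual is added after blocks 1 and 2 only.
    if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
        hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
```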

diffusers/models/transformers/stable_audio_transformer.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 Stability AI and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,16 +21,12 @@ import torch.nn as nn
 import torch.utils.checkpoint

 from ...configuration_utils import ConfigMixin, register_to_config
-from ...models.attention import FeedForward
-from ...models.attention_processor import (
-    Attention,
-    AttentionProcessor,
-    StableAudioAttnProcessor2_0,
-)
-from ...models.modeling_utils import ModelMixin
-from ...models.transformers.transformer_2d import Transformer2DModelOutput
 from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import FeedForward
+from ..attention_processor import Attention, AttentionProcessor, StableAudioAttnProcessor2_0
+from ..modeling_utils import ModelMixin
+from ..transformers.transformer_2d import Transformer2DModelOutput


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

diffusers/models/transformers/t5_film_transformer.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -390,7 +390,7 @@ class T5LayerNorm(nn.Module):

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
-        # Square Layer Normalization https://
+        # Square Layer Normalization https://huggingface.co/papers/1910.07467 thus variance is calculated
         # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
         # half-precision inputs is done in fp32

@@ -407,7 +407,7 @@ class T5LayerNorm(nn.Module):
 class NewGELUActivation(nn.Module):
     """
     Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
-    the Gaussian Error Linear Units paper: https://
+    the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415
     """

     def forward(self, input: torch.Tensor) -> torch.Tensor: