diffusers 0.33.1__py3-none-any.whl → 0.34.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +48 -1
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/diffusers_cli.py +1 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/hooks/faster_cache.py +2 -2
- diffusers/hooks/group_offloading.py +128 -29
- diffusers/hooks/hooks.py +2 -2
- diffusers/hooks/layerwise_casting.py +3 -3
- diffusers/hooks/pyramid_attention_broadcast.py +1 -1
- diffusers/image_processor.py +7 -2
- diffusers/loaders/__init__.py +4 -0
- diffusers/loaders/ip_adapter.py +5 -14
- diffusers/loaders/lora_base.py +212 -111
- diffusers/loaders/lora_conversion_utils.py +275 -34
- diffusers/loaders/lora_pipeline.py +1554 -819
- diffusers/loaders/peft.py +52 -109
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_model.py +20 -4
- diffusers/loaders/single_file_utils.py +225 -5
- diffusers/loaders/textual_inversion.py +3 -2
- diffusers/loaders/transformer_flux.py +1 -1
- diffusers/loaders/transformer_sd3.py +2 -2
- diffusers/loaders/unet.py +2 -16
- diffusers/loaders/unet_loader_utils.py +1 -1
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +15 -1
- diffusers/models/activations.py +5 -5
- diffusers/models/adapter.py +2 -3
- diffusers/models/attention.py +4 -4
- diffusers/models/attention_flax.py +10 -10
- diffusers/models/attention_processor.py +14 -10
- diffusers/models/auto_model.py +47 -10
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_dc.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1108 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_wan.py +256 -22
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +13 -2
- diffusers/models/autoencoders/vq_model.py +2 -2
- diffusers/models/cache_utils.py +1 -1
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flux.py +1 -1
- diffusers/models/controlnet_sd3.py +1 -1
- diffusers/models/controlnet_sparsectrl.py +1 -1
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -3
- diffusers/models/controlnets/controlnet_flax.py +1 -1
- diffusers/models/controlnets/controlnet_flux.py +16 -15
- diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
- diffusers/models/controlnets/controlnet_sana.py +290 -0
- diffusers/models/controlnets/controlnet_sd3.py +1 -1
- diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
- diffusers/models/controlnets/controlnet_union.py +1 -1
- diffusers/models/controlnets/controlnet_xs.py +7 -7
- diffusers/models/controlnets/multicontrolnet.py +4 -5
- diffusers/models/controlnets/multicontrolnet_union.py +5 -6
- diffusers/models/downsampling.py +2 -2
- diffusers/models/embeddings.py +10 -12
- diffusers/models/embeddings_flax.py +2 -2
- diffusers/models/lora.py +3 -3
- diffusers/models/modeling_utils.py +44 -14
- diffusers/models/normalization.py +4 -4
- diffusers/models/resnet.py +2 -2
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
- diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
- diffusers/models/transformers/consisid_transformer_3d.py +1 -1
- diffusers/models/transformers/dit_transformer_2d.py +2 -2
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
- diffusers/models/transformers/latte_transformer_3d.py +4 -5
- diffusers/models/transformers/lumina_nextdit2d.py +2 -2
- diffusers/models/transformers/pixart_transformer_2d.py +3 -3
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/sana_transformer.py +8 -3
- diffusers/models/transformers/stable_audio_transformer.py +5 -9
- diffusers/models/transformers/t5_film_transformer.py +3 -3
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +1 -1
- diffusers/models/transformers/transformer_chroma.py +742 -0
- diffusers/models/transformers/transformer_cogview3plus.py +5 -10
- diffusers/models/transformers/transformer_cogview4.py +317 -25
- diffusers/models/transformers/transformer_cosmos.py +579 -0
- diffusers/models/transformers/transformer_flux.py +9 -11
- diffusers/models/transformers/transformer_hidream_image.py +942 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -8
- diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- diffusers/models/transformers/transformer_ltx.py +2 -2
- diffusers/models/transformers/transformer_lumina2.py +1 -1
- diffusers/models/transformers/transformer_mochi.py +1 -1
- diffusers/models/transformers/transformer_omnigen.py +2 -2
- diffusers/models/transformers/transformer_sd3.py +7 -7
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/transformers/transformer_wan.py +24 -8
- diffusers/models/transformers/transformer_wan_vace.py +393 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +1 -1
- diffusers/models/unets/unet_2d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
- diffusers/models/unets/unet_2d_condition.py +2 -2
- diffusers/models/unets/unet_2d_condition_flax.py +2 -2
- diffusers/models/unets/unet_3d_blocks.py +1 -1
- diffusers/models/unets/unet_3d_condition.py +3 -3
- diffusers/models/unets/unet_i2vgen_xl.py +3 -3
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +2 -2
- diffusers/models/unets/unet_stable_cascade.py +1 -1
- diffusers/models/upsampling.py +2 -2
- diffusers/models/vae_flax.py +2 -2
- diffusers/models/vq_model.py +1 -1
- diffusers/pipelines/__init__.py +37 -6
- diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
- diffusers/pipelines/amused/pipeline_amused.py +7 -6
- diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
- diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
- diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +23 -13
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
- diffusers/pipelines/auto_pipeline.py +6 -7
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
- diffusers/pipelines/chroma/__init__.py +49 -0
- diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- diffusers/pipelines/chroma/pipeline_output.py +21 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +8 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +8 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +8 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +8 -8
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
- diffusers/pipelines/cogview4/pipeline_cogview4.py +7 -7
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
- diffusers/pipelines/consisid/consisid_utils.py +2 -2
- diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +8 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +5 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
- diffusers/pipelines/cosmos/__init__.py +54 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
- diffusers/pipelines/cosmos/pipeline_output.py +40 -0
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
- diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +1 -1
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
- diffusers/pipelines/flux/modeling_flux.py +1 -1
- diffusers/pipelines/flux/pipeline_flux.py +10 -17
- diffusers/pipelines/flux/pipeline_flux_control.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +30 -22
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +2 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_img2img.py +39 -6
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +11 -6
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/free_init_utils.py +2 -2
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hidream_image/__init__.py +47 -0
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
- diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
- diffusers/pipelines/hunyuan_video/__init__.py +2 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
- diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
- diffusers/pipelines/kolors/text_encoder.py +3 -3
- diffusers/pipelines/kolors/tokenizer.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
- diffusers/pipelines/latte/pipeline_latte.py +12 -12
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
- diffusers/pipelines/ltx/__init__.py +4 -0
- diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +51 -6
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +107 -29
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +50 -6
- diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
- diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
- diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
- diffusers/pipelines/mochi/pipeline_mochi.py +6 -6
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
- diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
- diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
- diffusers/pipelines/onnx_utils.py +15 -2
- diffusers/pipelines/pag/pag_utils.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
- diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
- diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
- diffusers/pipelines/pia/pipeline_pia.py +8 -6
- diffusers/pipelines/pipeline_flax_utils.py +3 -4
- diffusers/pipelines/pipeline_loading_utils.py +89 -13
- diffusers/pipelines/pipeline_utils.py +105 -33
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +11 -11
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +11 -11
- diffusers/pipelines/sana/__init__.py +4 -0
- diffusers/pipelines/sana/pipeline_sana.py +23 -21
- diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
- diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +3 -3
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
- diffusers/pipelines/stable_diffusion/__init__.py +0 -7
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +9 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +8 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
- diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
- diffusers/pipelines/unclip/text_proj.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
- diffusers/pipelines/visualcloze/__init__.py +52 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
- diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
- diffusers/pipelines/wan/__init__.py +2 -0
- diffusers/pipelines/wan/pipeline_wan.py +13 -10
- diffusers/pipelines/wan/pipeline_wan_i2v.py +38 -18
- diffusers/pipelines/wan/pipeline_wan_vace.py +976 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
- diffusers/quantizers/__init__.py +179 -1
- diffusers/quantizers/base.py +6 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
- diffusers/quantizers/bitsandbytes/utils.py +10 -7
- diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
- diffusers/quantizers/gguf/utils.py +16 -13
- diffusers/quantizers/quantization_config.py +18 -16
- diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +5 -1
- diffusers/schedulers/__init__.py +3 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -1
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
- diffusers/schedulers/scheduling_ddim.py +8 -8
- diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_ddim_flax.py +6 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
- diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
- diffusers/schedulers/scheduling_ddpm.py +9 -9
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
- diffusers/schedulers/scheduling_deis_multistep.py +8 -8
- diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -12
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -13
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
- diffusers/schedulers/scheduling_edm_euler.py +20 -11
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
- diffusers/schedulers/scheduling_heun_discrete.py +2 -2
- diffusers/schedulers/scheduling_ipndm.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
- diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
- diffusers/schedulers/scheduling_lcm.py +3 -3
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +4 -4
- diffusers/schedulers/scheduling_pndm_flax.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +9 -9
- diffusers/schedulers/scheduling_sasolver.py +15 -15
- diffusers/schedulers/scheduling_scm.py +1 -1
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
- diffusers/schedulers/scheduling_tcd.py +3 -3
- diffusers/schedulers/scheduling_unclip.py +5 -5
- diffusers/schedulers/scheduling_unipc_multistep.py +11 -11
- diffusers/schedulers/scheduling_utils.py +1 -1
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +13 -5
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +120 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +225 -0
- diffusers/utils/dynamic_modules_utils.py +21 -3
- diffusers/utils/export_utils.py +1 -1
- diffusers/utils/import_utils.py +81 -18
- diffusers/utils/logging.py +1 -1
- diffusers/utils/outputs.py +2 -1
- diffusers/utils/peft_utils.py +91 -8
- diffusers/utils/state_dict_utils.py +20 -3
- diffusers/utils/testing_utils.py +59 -7
- diffusers/utils/torch_utils.py +25 -5
- diffusers/video_processor.py +2 -2
- {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/METADATA +70 -55
- diffusers-0.34.0.dist-info/RECORD +639 -0
- {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/WHEEL +1 -1
- diffusers-0.33.1.dist-info/RECORD +0 -608
- {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/LICENSE +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/top_level.txt +0 -0
diffusers/models/controlnets/controlnet_flux.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,12 +20,12 @@ import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...models.attention_processor import AttentionProcessor
-from ...models.modeling_utils import ModelMixin
 from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ..attention_processor import AttentionProcessor
 from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
 from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
 from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
 from ..transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
 
 
@@ -430,7 +430,7 @@ class FluxMultiControlNetModel(ModelMixin):
     ) -> Union[FluxControlNetOutput, Tuple]:
         # ControlNet-Union with multiple conditions
         # only load one ControlNet for saving memories
-        if len(self.nets) == 1 and self.nets[0].union:
+        if len(self.nets) == 1:
             controlnet = self.nets[0]
 
             for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
@@ -454,17 +454,18 @@ class FluxMultiControlNetModel(ModelMixin):
                     control_block_samples = block_samples
                     control_single_block_samples = single_block_samples
                 else:
-                    control_block_samples = [
-                        control_block_sample + block_sample
-                        for control_block_sample, block_sample in zip(control_block_samples, block_samples)
-                    ]
-
-                    control_single_block_samples = [
-                        control_single_block_sample + block_sample
-                        for control_single_block_sample, block_sample in zip(
-                            control_single_block_samples, single_block_samples
-                        )
-                    ]
+                    if block_samples is not None and control_block_samples is not None:
+                        control_block_samples = [
+                            control_block_sample + block_sample
+                            for control_block_sample, block_sample in zip(control_block_samples, block_samples)
+                        ]
+                    if single_block_samples is not None and control_single_block_samples is not None:
+                        control_single_block_samples = [
+                            control_single_block_sample + block_sample
+                            for control_single_block_sample, block_sample in zip(
+                                control_single_block_samples, single_block_samples
+                            )
+                        ]
 
             # Regular Multi-ControlNets
             # load all ControlNets into memories
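The last hunk is the functional change in this file: the 0.33.1 `else:` branch summed both residual lists unconditionally, so a ControlNet that returned only double-stream or only single-stream samples fed `None` into `zip`. A standalone sketch of the new guarded pattern, with plain floats standing in for the real block-sample tensors:

# Illustrative values only — in the model these are lists of tensors.
block_samples = [1.0, 2.0]            # double-stream outputs from this ControlNet
single_block_samples = None           # this ControlNet produced no single-stream outputs
control_block_samples = [0.5, 0.5]    # running totals from earlier ControlNets
control_single_block_samples = None

if block_samples is not None and control_block_samples is not None:
    control_block_samples = [c + b for c, b in zip(control_block_samples, block_samples)]
if single_block_samples is not None and control_single_block_samples is not None:
    control_single_block_samples = [
        c + b for c, b in zip(control_single_block_samples, single_block_samples)
    ]

print(control_block_samples)          # [1.5, 2.5]
print(control_single_block_samples)   # None — the unguarded 0.33.1 code would have raised here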
diffusers/models/controlnets/controlnet_hunyuan.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
+# Copyright 2025 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -103,7 +103,7 @@ class HunyuanDiT2DControlNetModel(ModelMixin, ConfigMixin):
                     activation_fn=activation_fn,
                     ff_inner_dim=int(self.inner_dim * mlp_ratio),
                     cross_attention_dim=cross_attention_dim,
-                    qk_norm=True,  # See http://arxiv.org/abs/2302.05442 for details.
+                    qk_norm=True,  # See https://huggingface.co/papers/2302.05442 for details.
                     skip=False,  # always False as it is the first half of the model
                 )
                 for layer in range(transformer_num_layers // 2 - 1)
diffusers/models/controlnets/controlnet_sana.py
ADDED
@@ -0,0 +1,290 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ..attention_processor import AttentionProcessor
+from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormSingle, RMSNorm
+from ..transformers.sana_transformer import SanaTransformerBlock
+from .controlnet import zero_module
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+@dataclass
+class SanaControlNetOutput(BaseOutput):
+    controlnet_block_samples: Tuple[torch.Tensor]
+
+
+class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+    _supports_gradient_checkpointing = True
+    _no_split_modules = ["SanaTransformerBlock", "PatchEmbed"]
+    _skip_layerwise_casting_patterns = ["patch_embed", "norm"]
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 32,
+        out_channels: Optional[int] = 32,
+        num_attention_heads: int = 70,
+        attention_head_dim: int = 32,
+        num_layers: int = 7,
+        num_cross_attention_heads: Optional[int] = 20,
+        cross_attention_head_dim: Optional[int] = 112,
+        cross_attention_dim: Optional[int] = 2240,
+        caption_channels: int = 2304,
+        mlp_ratio: float = 2.5,
+        dropout: float = 0.0,
+        attention_bias: bool = False,
+        sample_size: int = 32,
+        patch_size: int = 1,
+        norm_elementwise_affine: bool = False,
+        norm_eps: float = 1e-6,
+        interpolation_scale: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+
+        out_channels = out_channels or in_channels
+        inner_dim = num_attention_heads * attention_head_dim
+
+        # 1. Patch Embedding
+        self.patch_embed = PatchEmbed(
+            height=sample_size,
+            width=sample_size,
+            patch_size=patch_size,
+            in_channels=in_channels,
+            embed_dim=inner_dim,
+            interpolation_scale=interpolation_scale,
+            pos_embed_type="sincos" if interpolation_scale is not None else None,
+        )
+
+        # 2. Additional condition embeddings
+        self.time_embed = AdaLayerNormSingle(inner_dim)
+
+        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
+        self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
+
+        # 3. Transformer blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                SanaTransformerBlock(
+                    inner_dim,
+                    num_attention_heads,
+                    attention_head_dim,
+                    dropout=dropout,
+                    num_cross_attention_heads=num_cross_attention_heads,
+                    cross_attention_head_dim=cross_attention_head_dim,
+                    cross_attention_dim=cross_attention_dim,
+                    attention_bias=attention_bias,
+                    norm_elementwise_affine=norm_elementwise_affine,
+                    norm_eps=norm_eps,
+                    mlp_ratio=mlp_ratio,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        # controlnet_blocks
+        self.controlnet_blocks = nn.ModuleList([])
+
+        self.input_block = zero_module(nn.Linear(inner_dim, inner_dim))
+        for _ in range(len(self.transformer_blocks)):
+            controlnet_block = nn.Linear(inner_dim, inner_dim)
+            controlnet_block = zero_module(controlnet_block)
+            self.controlnet_blocks.append(controlnet_block)
+
+        self.gradient_checkpointing = False
+
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        controlnet_cond: torch.Tensor,
+        conditioning_scale: float = 1.0,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch, 1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+        if attention_mask is not None and attention_mask.ndim == 2:
+            # assume that mask is expressed as:
+            #   (1 = keep, 0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #   (keep = +0, discard = -10000.0)
+            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+        # 1. Input
+        batch_size, num_channels, height, width = hidden_states.shape
+        p = self.config.patch_size
+        post_patch_height, post_patch_width = height // p, width // p
+
+        hidden_states = self.patch_embed(hidden_states)
+        hidden_states = hidden_states + self.input_block(self.patch_embed(controlnet_cond.to(hidden_states.dtype)))
+
+        timestep, embedded_timestep = self.time_embed(
+            timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+        )
+
+        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+        encoder_hidden_states = self.caption_norm(encoder_hidden_states)
+
+        # 2. Transformer blocks
+        block_res_samples = ()
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            for block in self.transformer_blocks:
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    timestep,
+                    post_patch_height,
+                    post_patch_width,
+                )
+                block_res_samples = block_res_samples + (hidden_states,)
+        else:
+            for block in self.transformer_blocks:
+                hidden_states = block(
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    timestep,
+                    post_patch_height,
+                    post_patch_width,
+                )
+                block_res_samples = block_res_samples + (hidden_states,)
+
+        # 3. ControlNet blocks
+        controlnet_block_res_samples = ()
+        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
+            block_res_sample = controlnet_block(block_res_sample)
+            controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
+
+        if not return_dict:
+            return (controlnet_block_res_samples,)
+
+        return SanaControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
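Since controlnet_sana.py is entirely new in 0.34.0, a rough usage sketch may help. The constructor sizes below are deliberately tiny toy values, not the sizes real SANA checkpoints use (those default to num_attention_heads=70, num_layers=7, and so on, as shown above); only the import path and the signatures are taken from the diff:

import torch
from diffusers.models.controlnets.controlnet_sana import SanaControlNetModel

# Toy configuration for illustration only.
controlnet = SanaControlNetModel(
    in_channels=4,
    num_attention_heads=2,
    attention_head_dim=8,
    num_layers=2,
    num_cross_attention_heads=2,
    cross_attention_head_dim=8,
    cross_attention_dim=16,
    caption_channels=16,
    sample_size=8,
)

latents = torch.randn(1, 4, 8, 8)      # noisy latents (batch, in_channels, h, w)
cond = torch.randn(1, 4, 8, 8)         # control signal, patch-embedded and zero-gated inside forward
prompt_embeds = torch.randn(1, 5, 16)  # text states sized to caption_channels
timestep = torch.tensor([10])

out = controlnet(latents, prompt_embeds, timestep, controlnet_cond=cond)
# One zero-initialized residual per transformer block, consumed by the SANA
# transformer during denoising.
print(len(out.controlnet_block_samples))  # 2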
diffusers/models/controlnets/controlnet_sd3.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
+# Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diffusers/models/controlnets/controlnet_sparsectrl.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -96,7 +96,7 @@ class SparseControlNetConditioningEmbedding(nn.Module):
 class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     """
     A SparseControlNet model as described in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion
-    Models](https://arxiv.org/abs/2311.16933).
+    Models](https://huggingface.co/papers/2311.16933).
 
     Args:
         in_channels (`int`, defaults to 4):
diffusers/models/controlnets/controlnet_xs.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -734,17 +734,17 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
             unet (`UNet2DConditionModel`):
                 The UNet model we want to control.
             controlnet (`ControlNetXSAdapter`):
-                The …
+                The ControlNet-XS adapter with which the UNet will be fused. If none is given, a new ControlNet-XS
                 adapter will be created.
             size_ratio (float, *optional*, defaults to `None`):
-                Used to …
+                Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
             ctrl_block_out_channels (`List[int]`, *optional*, defaults to `None`):
-                Used to …
+                Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details,
                 where this parameter is called `block_out_channels`.
             time_embedding_mix (`float`, *optional*, defaults to None):
-                Used to …
+                Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
             ctrl_optional_kwargs (`Dict`, *optional*, defaults to `None`):
-                Passed to the `init` of the new …
+                Passed to the `init` of the new controlnet if no controlnet was given.
         """
         if controlnet is None:
             controlnet = ControlNetXSAdapter.from_unet(
@@ -942,7 +942,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
 
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
     def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
-        r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+        r"""Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497.
 
         The suffixes after the scaling factors represent the stage blocks where they are being applied.
 
diffusers/models/controlnets/multicontrolnet.py
CHANGED
@@ -4,9 +4,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import torch
 from torch import nn
 
-from ...models.controlnets.controlnet import ControlNetModel, ControlNetOutput
-from ...models.modeling_utils import ModelMixin
 from ...utils import logging
+from ..controlnets.controlnet import ControlNetModel, ControlNetOutput
+from ..modeling_utils import ModelMixin
 
 
 logger = logging.get_logger(__name__)
@@ -130,9 +130,8 @@ class MultiControlNetModel(ModelMixin):
                 A path to a *directory* containing model weights saved using
                 [`~models.controlnets.multicontrolnet.MultiControlNetModel.save_pretrained`], e.g.,
                 `./my_model_directory/controlnet`.
-            torch_dtype (`str` or `torch.dtype`, *optional*):
-                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed, the dtype
-                will be automatically derived from the model's weights.
+            torch_dtype (`torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model under this dtype.
             output_loading_info(`bool`, *optional*, defaults to `False`):
                 Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
             device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
diffusers/models/controlnets/multicontrolnet_union.py
CHANGED
@@ -4,10 +4,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import torch
 from torch import nn
 
-from ...models.controlnets.controlnet import ControlNetOutput
-from ...models.controlnets.controlnet_union import ControlNetUnionModel
-from ...models.modeling_utils import ModelMixin
 from ...utils import logging
+from ..controlnets.controlnet import ControlNetOutput
+from ..controlnets.controlnet_union import ControlNetUnionModel
+from ..modeling_utils import ModelMixin
 
 
 logger = logging.get_logger(__name__)
@@ -143,9 +143,8 @@ class MultiControlNetUnionModel(ModelMixin):
                 A path to a *directory* containing model weights saved using
                 [`~models.controlnets.multicontrolnet.MultiControlNetUnionModel.save_pretrained`], e.g.,
                 `./my_model_directory/controlnet`.
-            torch_dtype (`str` or `torch.dtype`, *optional*):
-                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed, the dtype
-                will be automatically derived from the model's weights.
+            torch_dtype (`torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model under this dtype.
             output_loading_info(`bool`, *optional*, defaults to `False`):
                 Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
             device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
diffusers/models/downsampling.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -286,7 +286,7 @@ class KDownsample2D(nn.Module):
 
 
 class CogVideoXDownsample3D(nn.Module):
-    # Todo: Wait for paper relase.
+    # Todo: Wait for paper release.
     r"""
     A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI
 
diffusers/models/embeddings.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,7 +31,7 @@ def get_timestep_embedding(
     downscale_freq_shift: float = 1,
     scale: float = 1,
     max_period: int = 10000,
-):
+) -> torch.Tensor:
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
 
@@ -97,7 +97,7 @@ def get_3d_sincos_pos_embed(
         The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
         spatial dimensions (height and width).
     temporal_size (`int`):
-        The temporal dimension of postional embeddings (number of frames).
+        The temporal dimension of positional embeddings (number of frames).
     spatial_interpolation_scale (`float`, defaults to 1.0):
         Scale factor for spatial grid interpolation.
     temporal_interpolation_scale (`float`, defaults to 1.0):
@@ -169,7 +169,7 @@ def _get_3d_sincos_pos_embed_np(
         The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
         spatial dimensions (height and width).
     temporal_size (`int`):
-        The temporal dimension of postional embeddings (number of frames).
+        The temporal dimension of positional embeddings (number of frames).
     spatial_interpolation_scale (`float`, defaults to 1.0):
         Scale factor for spatial grid interpolation.
     temporal_interpolation_scale (`float`, defaults to 1.0):
@@ -1149,9 +1149,7 @@ def get_1d_rotary_pos_embed(
 
     theta = theta * ntk_factor
     freqs = (
-        1.0
-        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
-        / linear_factor
+        1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor
     )  # [D/2]
     freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
     is_npu = freqs.device.type == "npu"
@@ -1201,11 +1199,11 @@ def apply_rotary_emb(
 
     if use_real_unbind_dim == -1:
         # Used for flux, cogvideox, hunyuan-dit
-        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
+        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
         x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
     elif use_real_unbind_dim == -2:
-        # Used for Stable Audio, OmniGen and CogView4
-        x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
+        # Used for Stable Audio, OmniGen, CogView4 and Cosmos
+        x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, H, S, D//2]
         x_rotated = torch.cat([-x_imag, x_real], dim=-1)
     else:
         raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
@@ -1327,7 +1325,7 @@ class Timesteps(nn.Module):
         self.downscale_freq_shift = downscale_freq_shift
         self.scale = scale
 
-    def forward(self, timesteps):
+    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
         t_emb = get_timestep_embedding(
             timesteps,
             self.num_channels,
@@ -1401,7 +1399,7 @@ class ImagePositionalEmbeddings(nn.Module):
     Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
     height and width of the latent space.
 
-    For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
+    For more details, see figure 10 of the dall-e paper: https://huggingface.co/papers/2102.12092
 
     For VQ-diffusion:
 
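Two of the embeddings.py hunks are worth unpacking. The `get_1d_rotary_pos_embed` change only collapses a no-op: for an even `dim` (RoPE head dims are even in practice), `torch.arange(0, dim, 2)` already has exactly `dim // 2` elements, so the removed `[: (dim // 2)]` slice changed nothing. A quick standalone check, using plain tensors rather than the library call:

import torch

dim, theta, linear_factor = 8, 10000.0, 1.0
idx = torch.arange(0, dim, 2)            # tensor([0, 2, 4, 6]) — already dim // 2 elements
assert torch.equal(idx, idx[: (dim // 2)])

old_freqs = 1.0 / (theta ** (idx[: (dim // 2)] / dim)) / linear_factor  # 0.33.1 form
new_freqs = 1.0 / (theta ** (idx / dim)) / linear_factor                # 0.34.0 form
assert torch.equal(old_freqs, new_freqs)

The `apply_rotary_emb` hunks, by contrast, only correct the shape comments (the inputs are laid out `[B, H, S, D]`, so the unbound halves are `[B, H, S, D//2]`) and extend the note to cover the new Cosmos models; the computation itself is untouched.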
diffusers/models/embeddings_flax.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -89,7 +89,7 @@ class FlaxTimestepEmbedding(nn.Module):
 
 class FlaxTimesteps(nn.Module):
     r"""
-    Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239
+    Wrapper Module for sinusoidal Time step Embeddings as described in https://huggingface.co/papers/2006.11239
 
     Args:
         dim (`int`, *optional*, defaults to `32`):
diffusers/models/lora.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -38,7 +38,7 @@ if is_transformers_available():
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-def text_encoder_attn_modules(text_encoder):
+def text_encoder_attn_modules(text_encoder: nn.Module):
     attn_modules = []
 
     if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
@@ -52,7 +52,7 @@ def text_encoder_attn_modules(text_encoder):
     return attn_modules
 
 
-def text_encoder_mlp_modules(text_encoder):
+def text_encoder_mlp_modules(text_encoder: nn.Module):
     mlp_modules = []
 
     if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
|