diffusers-0.33.1-py3-none-any.whl → diffusers-0.35.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +145 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +3 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +2 -2
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +3 -3
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +9 -8
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +332 -227
- diffusers/hooks/hooks.py +58 -3
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +5 -10
- diffusers/hooks/pyramid_attention_broadcast.py +15 -12
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/image_processor.py +7 -2
- diffusers/loaders/__init__.py +10 -0
- diffusers/loaders/ip_adapter.py +260 -18
- diffusers/loaders/lora_base.py +261 -127
- diffusers/loaders/lora_conversion_utils.py +657 -35
- diffusers/loaders/lora_pipeline.py +2778 -1246
- diffusers/loaders/peft.py +78 -112
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_model.py +64 -15
- diffusers/loaders/single_file_utils.py +395 -7
- diffusers/loaders/textual_inversion.py +3 -2
- diffusers/loaders/transformer_flux.py +10 -11
- diffusers/loaders/transformer_sd3.py +8 -3
- diffusers/loaders/unet.py +24 -21
- diffusers/loaders/unet_loader_utils.py +6 -3
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +23 -1
- diffusers/models/activations.py +5 -5
- diffusers/models/adapter.py +2 -3
- diffusers/models/attention.py +488 -7
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_flax.py +10 -10
- diffusers/models/attention_processor.py +113 -667
- diffusers/models/auto_model.py +49 -12
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_dc.py +17 -4
- diffusers/models/autoencoders/autoencoder_kl.py +5 -5
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +13 -2
- diffusers/models/autoencoders/vq_model.py +2 -2
- diffusers/models/cache_utils.py +32 -10
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flux.py +1 -1
- diffusers/models/controlnet_sd3.py +1 -1
- diffusers/models/controlnet_sparsectrl.py +1 -1
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -3
- diffusers/models/controlnets/controlnet_flax.py +1 -1
- diffusers/models/controlnets/controlnet_flux.py +21 -20
- diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
- diffusers/models/controlnets/controlnet_sana.py +290 -0
- diffusers/models/controlnets/controlnet_sd3.py +1 -1
- diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
- diffusers/models/controlnets/controlnet_union.py +5 -5
- diffusers/models/controlnets/controlnet_xs.py +7 -7
- diffusers/models/controlnets/multicontrolnet.py +4 -5
- diffusers/models/controlnets/multicontrolnet_union.py +5 -6
- diffusers/models/downsampling.py +2 -2
- diffusers/models/embeddings.py +36 -46
- diffusers/models/embeddings_flax.py +2 -2
- diffusers/models/lora.py +3 -3
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +203 -108
- diffusers/models/normalization.py +4 -4
- diffusers/models/resnet.py +2 -2
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/transformers/__init__.py +7 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
- diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
- diffusers/models/transformers/consisid_transformer_3d.py +1 -1
- diffusers/models/transformers/dit_transformer_2d.py +2 -2
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
- diffusers/models/transformers/latte_transformer_3d.py +4 -5
- diffusers/models/transformers/lumina_nextdit2d.py +2 -2
- diffusers/models/transformers/pixart_transformer_2d.py +3 -3
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/sana_transformer.py +8 -3
- diffusers/models/transformers/stable_audio_transformer.py +5 -9
- diffusers/models/transformers/t5_film_transformer.py +3 -3
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +1 -1
- diffusers/models/transformers/transformer_chroma.py +641 -0
- diffusers/models/transformers/transformer_cogview3plus.py +5 -10
- diffusers/models/transformers/transformer_cogview4.py +353 -27
- diffusers/models/transformers/transformer_cosmos.py +586 -0
- diffusers/models/transformers/transformer_flux.py +376 -138
- diffusers/models/transformers/transformer_hidream_image.py +942 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
- diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- diffusers/models/transformers/transformer_ltx.py +105 -24
- diffusers/models/transformers/transformer_lumina2.py +1 -1
- diffusers/models/transformers/transformer_mochi.py +1 -1
- diffusers/models/transformers/transformer_omnigen.py +2 -2
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_sd3.py +7 -7
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/transformers/transformer_wan.py +316 -87
- diffusers/models/transformers/transformer_wan_vace.py +387 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +1 -1
- diffusers/models/unets/unet_2d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
- diffusers/models/unets/unet_2d_condition.py +4 -3
- diffusers/models/unets/unet_2d_condition_flax.py +2 -2
- diffusers/models/unets/unet_3d_blocks.py +1 -1
- diffusers/models/unets/unet_3d_condition.py +3 -3
- diffusers/models/unets/unet_i2vgen_xl.py +3 -3
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +2 -2
- diffusers/models/unets/unet_stable_cascade.py +1 -1
- diffusers/models/upsampling.py +2 -2
- diffusers/models/vae_flax.py +2 -2
- diffusers/models/vq_model.py +1 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +68 -6
- diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
- diffusers/pipelines/amused/pipeline_amused.py +7 -6
- diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
- diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
- diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
- diffusers/pipelines/auto_pipeline.py +23 -20
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
- diffusers/pipelines/chroma/__init__.py +49 -0
- diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- diffusers/pipelines/chroma/pipeline_output.py +21 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
- diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
- diffusers/pipelines/consisid/consisid_utils.py +2 -2
- diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
- diffusers/pipelines/cosmos/__init__.py +54 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
- diffusers/pipelines/cosmos/pipeline_output.py +40 -0
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
- diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +4 -2
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/modeling_flux.py +1 -1
- diffusers/pipelines/flux/pipeline_flux.py +37 -36
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
- diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/free_init_utils.py +2 -2
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hidream_image/__init__.py +47 -0
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
- diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
- diffusers/pipelines/hunyuan_video/__init__.py +2 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
- diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
- diffusers/pipelines/kolors/text_encoder.py +3 -3
- diffusers/pipelines/kolors/tokenizer.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
- diffusers/pipelines/latte/pipeline_latte.py +12 -12
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
- diffusers/pipelines/ltx/__init__.py +4 -0
- diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
- diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
- diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
- diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
- diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
- diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
- diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
- diffusers/pipelines/onnx_utils.py +15 -2
- diffusers/pipelines/pag/pag_utils.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
- diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
- diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
- diffusers/pipelines/pia/pipeline_pia.py +8 -6
- diffusers/pipelines/pipeline_flax_utils.py +5 -6
- diffusers/pipelines/pipeline_loading_utils.py +113 -15
- diffusers/pipelines/pipeline_utils.py +127 -48
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/__init__.py +4 -0
- diffusers/pipelines/sana/pipeline_sana.py +23 -21
- diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
- diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +3 -3
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
- diffusers/pipelines/stable_diffusion/__init__.py +0 -7
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
- diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
- diffusers/pipelines/unclip/text_proj.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
- diffusers/pipelines/visualcloze/__init__.py +52 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
- diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
- diffusers/pipelines/wan/__init__.py +2 -0
- diffusers/pipelines/wan/pipeline_wan.py +91 -30
- diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
- diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
- diffusers/quantizers/__init__.py +3 -1
- diffusers/quantizers/base.py +17 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
- diffusers/quantizers/bitsandbytes/utils.py +10 -7
- diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
- diffusers/quantizers/gguf/utils.py +108 -16
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/quantization_config.py +18 -16
- diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
- diffusers/schedulers/__init__.py +3 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -1
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
- diffusers/schedulers/scheduling_ddim.py +8 -8
- diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_ddim_flax.py +6 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
- diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
- diffusers/schedulers/scheduling_ddpm.py +9 -9
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
- diffusers/schedulers/scheduling_deis_multistep.py +16 -9
- diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
- diffusers/schedulers/scheduling_edm_euler.py +20 -11
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
- diffusers/schedulers/scheduling_heun_discrete.py +2 -2
- diffusers/schedulers/scheduling_ipndm.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
- diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
- diffusers/schedulers/scheduling_lcm.py +3 -3
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +4 -4
- diffusers/schedulers/scheduling_pndm_flax.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +9 -9
- diffusers/schedulers/scheduling_sasolver.py +15 -15
- diffusers/schedulers/scheduling_scm.py +1 -2
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
- diffusers/schedulers/scheduling_tcd.py +3 -3
- diffusers/schedulers/scheduling_unclip.py +5 -5
- diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
- diffusers/schedulers/scheduling_utils.py +3 -3
- diffusers/schedulers/scheduling_utils_flax.py +2 -2
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +91 -5
- diffusers/utils/__init__.py +15 -0
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +4 -0
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +432 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
- diffusers/utils/dynamic_modules_utils.py +85 -8
- diffusers/utils/export_utils.py +1 -1
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +151 -18
- diffusers/utils/logging.py +1 -1
- diffusers/utils/outputs.py +2 -1
- diffusers/utils/peft_utils.py +96 -10
- diffusers/utils/state_dict_utils.py +20 -3
- diffusers/utils/testing_utils.py +195 -17
- diffusers/utils/torch_utils.py +43 -5
- diffusers/video_processor.py +2 -2
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
- diffusers-0.35.0.dist-info/RECORD +703 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
- diffusers-0.33.1.dist-info/RECORD +0 -608
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
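
One of the largest additions in this release is the new diffusers/modular_pipelines/ package. The hunk below is the beginning of the new file diffusers/modular_pipelines/stable_diffusion_xl/encoders.py (the +887 -0 entry above), shown truncated. As a rough, hypothetical sketch of how one of the block classes defined in it exposes its interface (assuming diffusers 0.35.0 and transformers are installed, and that the block can be constructed with no arguments as its definition in the hunk suggests; only properties shown in the hunk are used):

    # Sketch only: inspects the declared interface of one modular block.
    # The no-argument constructor is an assumption based on the class definition below.
    from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLTextEncoderStep

    block = StableDiffusionXLTextEncoderStep()
    print(block.description)          # short human-readable summary of the step
    print(block.expected_components)  # ComponentSpec entries: text_encoder, tokenizer_2, guider, ...
    print(block.inputs)               # InputParam entries: prompt, negative_prompt, clip_skip, ...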
@@ -0,0 +1,887 @@
|
|
1
|
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from typing import List, Optional, Tuple
|
16
|
+
|
17
|
+
import torch
|
18
|
+
from transformers import (
|
19
|
+
CLIPImageProcessor,
|
20
|
+
CLIPTextModel,
|
21
|
+
CLIPTextModelWithProjection,
|
22
|
+
CLIPTokenizer,
|
23
|
+
CLIPVisionModelWithProjection,
|
24
|
+
)
|
25
|
+
|
26
|
+
from ...configuration_utils import FrozenDict
|
27
|
+
from ...guiders import ClassifierFreeGuidance
|
28
|
+
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
29
|
+
from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
|
30
|
+
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
31
|
+
from ...models.lora import adjust_lora_scale_text_encoder
|
32
|
+
from ...utils import (
|
33
|
+
USE_PEFT_BACKEND,
|
34
|
+
logging,
|
35
|
+
scale_lora_layers,
|
36
|
+
unscale_lora_layers,
|
37
|
+
)
|
38
|
+
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
39
|
+
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
|
40
|
+
from .modular_pipeline import StableDiffusionXLModularPipeline
|
41
|
+
|
42
|
+
|
43
|
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
44
|
+
|
45
|
+
|
46
|
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
47
|
+
def retrieve_latents(
|
48
|
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
49
|
+
):
|
50
|
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
51
|
+
return encoder_output.latent_dist.sample(generator)
|
52
|
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
53
|
+
return encoder_output.latent_dist.mode()
|
54
|
+
elif hasattr(encoder_output, "latents"):
|
55
|
+
return encoder_output.latents
|
56
|
+
else:
|
57
|
+
raise AttributeError("Could not access latents of provided encoder_output")
|
58
|
+
|
59
|
+
|
60
|
+
class StableDiffusionXLIPAdapterStep(ModularPipelineBlocks):
|
61
|
+
model_name = "stable-diffusion-xl"
|
62
|
+
|
63
|
+
@property
|
64
|
+
def description(self) -> str:
|
65
|
+
return (
|
66
|
+
"IP Adapter step that prepares ip adapter image embeddings.\n"
|
67
|
+
"Note that this step only prepares the embeddings - in order for it to work correctly, "
|
68
|
+
"you need to load ip adapter weights into unet via ModularPipeline.load_ip_adapter() and pipeline.set_ip_adapter_scale().\n"
|
69
|
+
"See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)"
|
70
|
+
" for more details"
|
71
|
+
)
|
72
|
+
|
73
|
+
@property
|
74
|
+
def expected_components(self) -> List[ComponentSpec]:
|
75
|
+
return [
|
76
|
+
ComponentSpec("image_encoder", CLIPVisionModelWithProjection),
|
77
|
+
ComponentSpec(
|
78
|
+
"feature_extractor",
|
79
|
+
CLIPImageProcessor,
|
80
|
+
config=FrozenDict({"size": 224, "crop_size": 224}),
|
81
|
+
default_creation_method="from_config",
|
82
|
+
),
|
83
|
+
ComponentSpec("unet", UNet2DConditionModel),
|
84
|
+
ComponentSpec(
|
85
|
+
"guider",
|
86
|
+
ClassifierFreeGuidance,
|
87
|
+
config=FrozenDict({"guidance_scale": 7.5}),
|
88
|
+
default_creation_method="from_config",
|
89
|
+
),
|
90
|
+
]
|
91
|
+
|
92
|
+
@property
|
93
|
+
def inputs(self) -> List[InputParam]:
|
94
|
+
return [
|
95
|
+
InputParam(
|
96
|
+
"ip_adapter_image",
|
97
|
+
PipelineImageInput,
|
98
|
+
required=True,
|
99
|
+
description="The image(s) to be used as ip adapter",
|
100
|
+
)
|
101
|
+
]
|
102
|
+
|
103
|
+
@property
|
104
|
+
def intermediate_outputs(self) -> List[OutputParam]:
|
105
|
+
return [
|
106
|
+
OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"),
|
107
|
+
OutputParam(
|
108
|
+
"negative_ip_adapter_embeds",
|
109
|
+
type_hint=torch.Tensor,
|
110
|
+
description="Negative IP adapter image embeddings",
|
111
|
+
),
|
112
|
+
]
|
113
|
+
|
114
|
+
@staticmethod
|
115
|
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self->components
|
116
|
+
def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None):
|
117
|
+
dtype = next(components.image_encoder.parameters()).dtype
|
118
|
+
|
119
|
+
if not isinstance(image, torch.Tensor):
|
120
|
+
image = components.feature_extractor(image, return_tensors="pt").pixel_values
|
121
|
+
|
122
|
+
image = image.to(device=device, dtype=dtype)
|
123
|
+
if output_hidden_states:
|
124
|
+
image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
125
|
+
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
126
|
+
uncond_image_enc_hidden_states = components.image_encoder(
|
127
|
+
torch.zeros_like(image), output_hidden_states=True
|
128
|
+
).hidden_states[-2]
|
129
|
+
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
130
|
+
num_images_per_prompt, dim=0
|
131
|
+
)
|
132
|
+
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
133
|
+
else:
|
134
|
+
image_embeds = components.image_encoder(image).image_embeds
|
135
|
+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
136
|
+
uncond_image_embeds = torch.zeros_like(image_embeds)
|
137
|
+
|
138
|
+
return image_embeds, uncond_image_embeds
|
139
|
+
|
140
|
+
# modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
141
|
+
def prepare_ip_adapter_image_embeds(
|
142
|
+
self,
|
143
|
+
components,
|
144
|
+
ip_adapter_image,
|
145
|
+
ip_adapter_image_embeds,
|
146
|
+
device,
|
147
|
+
num_images_per_prompt,
|
148
|
+
prepare_unconditional_embeds,
|
149
|
+
):
|
150
|
+
image_embeds = []
|
151
|
+
if prepare_unconditional_embeds:
|
152
|
+
negative_image_embeds = []
|
153
|
+
if ip_adapter_image_embeds is None:
|
154
|
+
if not isinstance(ip_adapter_image, list):
|
155
|
+
ip_adapter_image = [ip_adapter_image]
|
156
|
+
|
157
|
+
if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers):
|
158
|
+
raise ValueError(
|
159
|
+
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
160
|
+
)
|
161
|
+
|
162
|
+
for single_ip_adapter_image, image_proj_layer in zip(
|
163
|
+
ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers
|
164
|
+
):
|
165
|
+
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
|
166
|
+
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
167
|
+
components, single_ip_adapter_image, device, 1, output_hidden_state
|
168
|
+
)
|
169
|
+
|
170
|
+
image_embeds.append(single_image_embeds[None, :])
|
171
|
+
if prepare_unconditional_embeds:
|
172
|
+
negative_image_embeds.append(single_negative_image_embeds[None, :])
|
173
|
+
else:
|
174
|
+
for single_image_embeds in ip_adapter_image_embeds:
|
175
|
+
if prepare_unconditional_embeds:
|
176
|
+
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
177
|
+
negative_image_embeds.append(single_negative_image_embeds)
|
178
|
+
image_embeds.append(single_image_embeds)
|
179
|
+
|
180
|
+
ip_adapter_image_embeds = []
|
181
|
+
for i, single_image_embeds in enumerate(image_embeds):
|
182
|
+
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
|
183
|
+
if prepare_unconditional_embeds:
|
184
|
+
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
|
185
|
+
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
|
186
|
+
|
187
|
+
single_image_embeds = single_image_embeds.to(device=device)
|
188
|
+
ip_adapter_image_embeds.append(single_image_embeds)
|
189
|
+
|
190
|
+
return ip_adapter_image_embeds
|
191
|
+
|
192
|
+
@torch.no_grad()
|
193
|
+
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
|
194
|
+
block_state = self.get_block_state(state)
|
195
|
+
|
196
|
+
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
|
197
|
+
block_state.device = components._execution_device
|
198
|
+
|
199
|
+
block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds(
|
200
|
+
components,
|
201
|
+
ip_adapter_image=block_state.ip_adapter_image,
|
202
|
+
ip_adapter_image_embeds=None,
|
203
|
+
device=block_state.device,
|
204
|
+
num_images_per_prompt=1,
|
205
|
+
prepare_unconditional_embeds=block_state.prepare_unconditional_embeds,
|
206
|
+
)
|
207
|
+
if block_state.prepare_unconditional_embeds:
|
208
|
+
block_state.negative_ip_adapter_embeds = []
|
209
|
+
for i, image_embeds in enumerate(block_state.ip_adapter_embeds):
|
210
|
+
negative_image_embeds, image_embeds = image_embeds.chunk(2)
|
211
|
+
block_state.negative_ip_adapter_embeds.append(negative_image_embeds)
|
212
|
+
block_state.ip_adapter_embeds[i] = image_embeds
|
213
|
+
|
214
|
+
self.set_block_state(state, block_state)
|
215
|
+
return components, state
|
216
|
+
|
217
|
+
|
218
|
+
class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks):
|
219
|
+
model_name = "stable-diffusion-xl"
|
220
|
+
|
221
|
+
@property
|
222
|
+
def description(self) -> str:
|
223
|
+
return "Text Encoder step that generate text_embeddings to guide the image generation"
|
224
|
+
|
225
|
+
@property
|
226
|
+
def expected_components(self) -> List[ComponentSpec]:
|
227
|
+
return [
|
228
|
+
ComponentSpec("text_encoder", CLIPTextModel),
|
229
|
+
ComponentSpec("text_encoder_2", CLIPTextModelWithProjection),
|
230
|
+
ComponentSpec("tokenizer", CLIPTokenizer),
|
231
|
+
ComponentSpec("tokenizer_2", CLIPTokenizer),
|
232
|
+
ComponentSpec(
|
233
|
+
"guider",
|
234
|
+
ClassifierFreeGuidance,
|
235
|
+
config=FrozenDict({"guidance_scale": 7.5}),
|
236
|
+
default_creation_method="from_config",
|
237
|
+
),
|
238
|
+
]
|
239
|
+
|
240
|
+
@property
|
241
|
+
def expected_configs(self) -> List[ConfigSpec]:
|
242
|
+
return [ConfigSpec("force_zeros_for_empty_prompt", True)]
|
243
|
+
|
244
|
+
@property
|
245
|
+
def inputs(self) -> List[InputParam]:
|
246
|
+
return [
|
247
|
+
InputParam("prompt"),
|
248
|
+
InputParam("prompt_2"),
|
249
|
+
InputParam("negative_prompt"),
|
250
|
+
InputParam("negative_prompt_2"),
|
251
|
+
InputParam("cross_attention_kwargs"),
|
252
|
+
InputParam("clip_skip"),
|
253
|
+
]
|
254
|
+
|
255
|
+
@property
|
256
|
+
def intermediate_outputs(self) -> List[OutputParam]:
|
257
|
+
return [
|
258
|
+
OutputParam(
|
259
|
+
"prompt_embeds",
|
260
|
+
type_hint=torch.Tensor,
|
261
|
+
kwargs_type="guider_input_fields",
|
262
|
+
description="text embeddings used to guide the image generation",
|
263
|
+
),
|
264
|
+
OutputParam(
|
265
|
+
"negative_prompt_embeds",
|
266
|
+
type_hint=torch.Tensor,
|
267
|
+
kwargs_type="guider_input_fields",
|
268
|
+
description="negative text embeddings used to guide the image generation",
|
269
|
+
),
|
270
|
+
OutputParam(
|
271
|
+
"pooled_prompt_embeds",
|
272
|
+
type_hint=torch.Tensor,
|
273
|
+
kwargs_type="guider_input_fields",
|
274
|
+
description="pooled text embeddings used to guide the image generation",
|
275
|
+
),
|
276
|
+
OutputParam(
|
277
|
+
"negative_pooled_prompt_embeds",
|
278
|
+
type_hint=torch.Tensor,
|
279
|
+
kwargs_type="guider_input_fields",
|
280
|
+
description="negative pooled text embeddings used to guide the image generation",
|
281
|
+
),
|
282
|
+
]
|
283
|
+
|
284
|
+
@staticmethod
|
285
|
+
def check_inputs(block_state):
|
286
|
+
if block_state.prompt is not None and (
|
287
|
+
not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)
|
288
|
+
):
|
289
|
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
|
290
|
+
elif block_state.prompt_2 is not None and (
|
291
|
+
not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list)
|
292
|
+
):
|
293
|
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}")
|
294
|
+
|
295
|
+
@staticmethod
|
296
|
+
def encode_prompt(
|
297
|
+
components,
|
298
|
+
prompt: str,
|
299
|
+
prompt_2: Optional[str] = None,
|
300
|
+
device: Optional[torch.device] = None,
|
301
|
+
num_images_per_prompt: int = 1,
|
302
|
+
prepare_unconditional_embeds: bool = True,
|
303
|
+
negative_prompt: Optional[str] = None,
|
304
|
+
negative_prompt_2: Optional[str] = None,
|
305
|
+
prompt_embeds: Optional[torch.Tensor] = None,
|
306
|
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
307
|
+
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
308
|
+
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
309
|
+
lora_scale: Optional[float] = None,
|
310
|
+
clip_skip: Optional[int] = None,
|
311
|
+
):
|
312
|
+
r"""
|
313
|
+
Encodes the prompt into text encoder hidden states.
|
314
|
+
|
315
|
+
Args:
|
316
|
+
prompt (`str` or `List[str]`, *optional*):
|
317
|
+
prompt to be encoded
|
318
|
+
prompt_2 (`str` or `List[str]`, *optional*):
|
319
|
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
320
|
+
used in both text-encoders
|
321
|
+
device: (`torch.device`):
|
322
|
+
torch device
|
323
|
+
num_images_per_prompt (`int`):
|
324
|
+
number of images that should be generated per prompt
|
325
|
+
prepare_unconditional_embeds (`bool`):
|
326
|
+
whether to use prepare unconditional embeddings or not
|
327
|
+
negative_prompt (`str` or `List[str]`, *optional*):
|
328
|
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
329
|
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
330
|
+
less than `1`).
|
331
|
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
332
|
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
333
|
+
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
|
334
|
+
prompt_embeds (`torch.Tensor`, *optional*):
|
335
|
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
336
|
+
provided, text embeddings will be generated from `prompt` input argument.
|
337
|
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
338
|
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
339
|
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
340
|
+
argument.
|
341
|
+
pooled_prompt_embeds (`torch.Tensor`, *optional*):
|
342
|
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
343
|
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
344
|
+
negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
|
345
|
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
346
|
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
347
|
+
input argument.
|
348
|
+
lora_scale (`float`, *optional*):
|
349
|
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
350
|
+
clip_skip (`int`, *optional*):
|
351
|
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
352
|
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
353
|
+
"""
|
354
|
+
device = device or components._execution_device
|
355
|
+
|
356
|
+
# set lora scale so that monkey patched LoRA
|
357
|
+
# function of text encoder can correctly access it
|
358
|
+
if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin):
|
359
|
+
components._lora_scale = lora_scale
|
360
|
+
|
361
|
+
# dynamically adjust the LoRA scale
|
362
|
+
if components.text_encoder is not None:
|
363
|
+
if not USE_PEFT_BACKEND:
|
364
|
+
adjust_lora_scale_text_encoder(components.text_encoder, lora_scale)
|
365
|
+
else:
|
366
|
+
scale_lora_layers(components.text_encoder, lora_scale)
|
367
|
+
|
368
|
+
if components.text_encoder_2 is not None:
|
369
|
+
if not USE_PEFT_BACKEND:
|
370
|
+
adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale)
|
371
|
+
else:
|
372
|
+
scale_lora_layers(components.text_encoder_2, lora_scale)
|
373
|
+
|
374
|
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
375
|
+
|
376
|
+
if prompt is not None:
|
377
|
+
batch_size = len(prompt)
|
378
|
+
else:
|
379
|
+
batch_size = prompt_embeds.shape[0]
|
380
|
+
|
381
|
+
# Define tokenizers and text encoders
|
382
|
+
tokenizers = (
|
383
|
+
[components.tokenizer, components.tokenizer_2]
|
384
|
+
if components.tokenizer is not None
|
385
|
+
else [components.tokenizer_2]
|
386
|
+
)
|
387
|
+
text_encoders = (
|
388
|
+
[components.text_encoder, components.text_encoder_2]
|
389
|
+
if components.text_encoder is not None
|
390
|
+
else [components.text_encoder_2]
|
391
|
+
)
|
392
|
+
|
393
|
+
if prompt_embeds is None:
|
394
|
+
prompt_2 = prompt_2 or prompt
|
395
|
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
396
|
+
|
397
|
+
# textual inversion: process multi-vector tokens if necessary
|
398
|
+
prompt_embeds_list = []
|
399
|
+
prompts = [prompt, prompt_2]
|
400
|
+
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
|
401
|
+
                if isinstance(components, TextualInversionLoaderMixin):
                    prompt = components.maybe_convert_prompt(prompt, tokenizer)

                text_inputs = tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                text_input_ids = text_inputs.input_ids
                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                    text_input_ids, untruncated_ids
                ):
                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                    logger.warning(
                        "The following part of your input was truncated because CLIP can only handle sequences up to"
                        f" {tokenizer.model_max_length} tokens: {removed_text}"
                    )

                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

                # We are always only interested in the pooled output of the final text encoder
                pooled_prompt_embeds = prompt_embeds[0]
                if clip_skip is None:
                    prompt_embeds = prompt_embeds.hidden_states[-2]
                else:
                    # "2" because SDXL always indexes from the penultimate layer.
                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

                prompt_embeds_list.append(prompt_embeds)

            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

        # get unconditional embeddings for classifier free guidance
        zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt
        if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt:
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
        elif prepare_unconditional_embeds and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt_2 = negative_prompt_2 or negative_prompt

            # normalize str to list
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
            negative_prompt_2 = (
                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
            )

            uncond_tokens: List[str]
            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = [negative_prompt, negative_prompt_2]

            negative_prompt_embeds_list = []
            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
                if isinstance(components, TextualInversionLoaderMixin):
                    negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer)

                max_length = prompt_embeds.shape[1]
                uncond_input = tokenizer(
                    negative_prompt,
                    padding="max_length",
                    max_length=max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                negative_prompt_embeds = text_encoder(
                    uncond_input.input_ids.to(device),
                    output_hidden_states=True,
                )
                # We are always only interested in the pooled output of the final text encoder
                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

                negative_prompt_embeds_list.append(negative_prompt_embeds)

            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

        if components.text_encoder_2 is not None:
            prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device)
        else:
            prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        if prepare_unconditional_embeds:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            if components.text_encoder_2 is not None:
                negative_prompt_embeds = negative_prompt_embeds.to(
                    dtype=components.text_encoder_2.dtype, device=device
                )
            else:
                negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )
        if prepare_unconditional_embeds:
            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
                bs_embed * num_images_per_prompt, -1
            )

        if components.text_encoder is not None:
            if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(components.text_encoder, lora_scale)

        if components.text_encoder_2 is not None:
            if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(components.text_encoder_2, lora_scale)

        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

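As a side note, here is a minimal sketch (not part of this diff) of what the `clip_skip` branch in `encode_prompt` above selects: with `output_hidden_states=True` a CLIP text encoder exposes one hidden state per layer, `hidden_states[-2]` is the penultimate layer SDXL uses by default, and `clip_skip` walks further back from it. Shapes and the number of layers below are made up for illustration.

import torch

# Stand-in for `text_encoder(...).hidden_states`: embedding output + 12 transformer layers.
hidden_states = [torch.randn(1, 77, 768) for _ in range(13)]

# Default: the penultimate layer.
default_embeds = hidden_states[-2]

# With clip_skip=2, walk two layers further back; the "+ 2" keeps the penultimate-layer convention.
clip_skip = 2
skipped_embeds = hidden_states[-(clip_skip + 2)]

print(default_embeds.shape, skipped_embeds.shape)  # torch.Size([1, 77, 768]) for both
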
    @torch.no_grad()
    def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
        # Get inputs and intermediates
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
        block_state.device = components._execution_device

        # Encode input prompt
        block_state.text_encoder_lora_scale = (
            block_state.cross_attention_kwargs.get("scale", None)
            if block_state.cross_attention_kwargs is not None
            else None
        )
        (
            block_state.prompt_embeds,
            block_state.negative_prompt_embeds,
            block_state.pooled_prompt_embeds,
            block_state.negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            components,
            block_state.prompt,
            block_state.prompt_2,
            block_state.device,
            1,
            block_state.prepare_unconditional_embeds,
            block_state.negative_prompt,
            block_state.negative_prompt_2,
            prompt_embeds=None,
            negative_prompt_embeds=None,
            pooled_prompt_embeds=None,
            negative_pooled_prompt_embeds=None,
            lora_scale=block_state.text_encoder_lora_scale,
            clip_skip=block_state.clip_skip,
        )
        # Add outputs
        self.set_block_state(state, block_state)
        return components, state

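A small sketch (not from the diff) of the "mps friendly" duplication pattern used at the end of `encode_prompt` above: instead of `repeat_interleave`, the embeddings are tiled along the sequence axis with `repeat` and then reshaped so each prompt's copies land next to each other in the batch axis. Shapes are illustrative.

import torch

batch_size, seq_len, dim = 2, 77, 2048
num_images_per_prompt = 3

prompt_embeds = torch.randn(batch_size, seq_len, dim)

# Tile along the sequence axis, then fold the copies back into the batch axis.
duplicated = prompt_embeds.repeat(1, num_images_per_prompt, 1)
duplicated = duplicated.view(batch_size * num_images_per_prompt, seq_len, -1)

# Equivalent to repeat_interleave along dim 0, but avoids that op.
assert torch.equal(duplicated, prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0))
print(duplicated.shape)  # torch.Size([6, 77, 2048])
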
class StableDiffusionXLVaeEncoderStep(ModularPipelineBlocks):
    model_name = "stable-diffusion-xl"

    @property
    def description(self) -> str:
        return "VAE encoder step that encodes the input image into a latent representation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict({"vae_scale_factor": 8}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("image", required=True),
            InputParam("height"),
            InputParam("width"),
            InputParam("generator"),
            InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
            InputParam(
                "preprocess_kwargs",
                type_hint=Optional[dict],
                description="A kwargs dictionary that, if specified, is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "image_latents",
                type_hint=torch.Tensor,
                description="The latents representing the reference image for image-to-image/inpainting generation",
            )
        ]

    # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
    # YiYi TODO: update the _encode_vae_image so that we can use #Copied from
    def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
        latents_mean = latents_std = None
        if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
            latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
        if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
            latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)

        dtype = image.dtype
        if components.vae.config.force_upcast:
            image = image.float()
            components.vae.to(dtype=torch.float32)

        if isinstance(generator, list):
            image_latents = [
                retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
                for i in range(image.shape[0])
            ]
            image_latents = torch.cat(image_latents, dim=0)
        else:
            image_latents = retrieve_latents(components.vae.encode(image), generator=generator)

        if components.vae.config.force_upcast:
            components.vae.to(dtype)

        image_latents = image_latents.to(dtype)
        if latents_mean is not None and latents_std is not None:
            latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
            latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
            image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
        else:
            image_latents = components.vae.config.scaling_factor * image_latents

        return image_latents

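A toy numeric sketch (not from the diff) of the two scaling paths at the end of `_encode_vae_image` above: when the VAE config provides `latents_mean`/`latents_std`, the latents are standardized before the scaling factor is applied; otherwise they are simply multiplied by `scaling_factor`. The statistics below are invented for the example.

import torch

scaling_factor = 0.13025            # e.g. the SDXL VAE scaling factor
latents = torch.randn(1, 4, 64, 64)

# Path 1: plain scaling (no latents_mean / latents_std in the VAE config).
scaled = scaling_factor * latents

# Path 2: standardize with per-channel statistics first (values invented for the example).
latents_mean = torch.zeros(1, 4, 1, 1)
latents_std = torch.full((1, 4, 1, 1), 2.0)
normalized = (latents - latents_mean) * scaling_factor / latents_std

print(scaled.std(), normalized.std())
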
    @torch.no_grad()
    def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
        block_state.device = components._execution_device
        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype

        image = components.image_processor.preprocess(
            block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
        )
        image = image.to(device=block_state.device, dtype=block_state.dtype)
        block_state.batch_size = image.shape[0]

        # if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
        if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
                f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
            )

        block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator)

        self.set_block_state(state, block_state)

        return components, state

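A sketch (not part of the diff) of why `__call__` above insists that a list of generators has exactly `batch_size` entries: the encoding path consumes one generator per image, so each sample's latent sampling is independently reproducible. `fake_encode` is a hypothetical stand-in for the `vae.encode` / `retrieve_latents` pair.

import torch

def fake_encode(image_slice: torch.Tensor, generator: torch.Generator) -> torch.Tensor:
    # Hypothetical stand-in: a real VAE would sample from the posterior using this generator.
    return torch.randn(image_slice.shape[0], 4, 8, 8, generator=generator)

images = torch.rand(3, 3, 64, 64)                               # batch of 3 preprocessed images
generators = [torch.Generator().manual_seed(s) for s in (0, 1, 2)]

if len(generators) != images.shape[0]:
    raise ValueError("generator list length must match the batch size")

# One generator per sample, mirroring the per-slice encode in `_encode_vae_image`.
latents = torch.cat(
    [fake_encode(images[i : i + 1], generators[i]) for i in range(images.shape[0])],
    dim=0,
)
print(latents.shape)  # torch.Size([3, 4, 8, 8])
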
class StableDiffusionXLInpaintVaeEncoderStep(ModularPipelineBlocks):
    model_name = "stable-diffusion-xl"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict({"vae_scale_factor": 8}),
                default_creation_method="from_config",
            ),
            ComponentSpec(
                "mask_processor",
                VaeImageProcessor,
                config=FrozenDict(
                    {"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}
                ),
                default_creation_method="from_config",
            ),
        ]

    @property
    def description(self) -> str:
        return "VAE encoder step that prepares the image and mask for the inpainting process"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("height"),
            InputParam("width"),
            InputParam("image", required=True),
            InputParam("mask_image", required=True),
            InputParam("padding_mask_crop"),
            InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
            InputParam("generator"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "image_latents", type_hint=torch.Tensor, description="The latent representation of the input image"
            ),
            OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"),
            OutputParam(
                "masked_image_latents",
                type_hint=torch.Tensor,
                description="The masked image latents to use for the inpainting process (only for inpainting-specific unet)",
            ),
            OutputParam(
                "crops_coords",
                type_hint=Optional[Tuple[int, int]],
                description="The crop coordinates to use for the preprocess/postprocess of the image and mask",
            ),
        ]

    # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
    # YiYi TODO: update the _encode_vae_image so that we can use #Copied from
    def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
        latents_mean = latents_std = None
        if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
            latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
        if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
            latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)

        dtype = image.dtype
        if components.vae.config.force_upcast:
            image = image.float()
            components.vae.to(dtype=torch.float32)

        if isinstance(generator, list):
            image_latents = [
                retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
                for i in range(image.shape[0])
            ]
            image_latents = torch.cat(image_latents, dim=0)
        else:
            image_latents = retrieve_latents(components.vae.encode(image), generator=generator)

        if components.vae.config.force_upcast:
            components.vae.to(dtype)

        image_latents = image_latents.to(dtype)
        if latents_mean is not None and latents_std is not None:
            latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
            latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
            image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
        else:
            image_latents = components.vae.config.scaling_factor * image_latents

        return image_latents

    # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
    # does not accept do_classifier_free_guidance
    def prepare_mask_latents(
        self, components, mask, masked_image, batch_size, height, width, dtype, device, generator
    ):
        # resize the mask to latents shape as we concatenate the mask to the latents
        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
        # and half precision
        mask = torch.nn.functional.interpolate(
            mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor)
        )
        mask = mask.to(device=device, dtype=dtype)

        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
        if mask.shape[0] < batch_size:
            if batch_size % mask.shape[0] != 0:
                raise ValueError(
                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                    " of masks that you pass is divisible by the total requested batch size."
                )
            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)

        if masked_image is not None and masked_image.shape[1] == 4:
            masked_image_latents = masked_image
        else:
            masked_image_latents = None

        if masked_image is not None:
            if masked_image_latents is None:
                masked_image = masked_image.to(device=device, dtype=dtype)
                masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)

            if masked_image_latents.shape[0] < batch_size:
                if batch_size % masked_image_latents.shape[0] != 0:
                    raise ValueError(
                        "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                        f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                        " Make sure the number of images that you pass is divisible by the total requested batch size."
                    )
                masked_image_latents = masked_image_latents.repeat(
                    batch_size // masked_image_latents.shape[0], 1, 1, 1
                )

            # aligning device to prevent device errors when concatenating it with the latent model input
            masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)

        return mask, masked_image_latents

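A minimal sketch (not from the diff) of what `prepare_mask_latents` above does to the mask: downsample it to the latent resolution with `interpolate` (the SDXL VAE scale factor is 8) and tile it when the requested batch size is a multiple of the number of masks passed in. Sizes below are illustrative.

import torch
import torch.nn.functional as F

vae_scale_factor = 8
height, width = 1024, 1024
batch_size = 4

mask = torch.rand(1, 1, height, width).round()      # one binarized mask

# Resize to the latent grid so it can later be concatenated with the latents.
mask_latent = F.interpolate(mask, size=(height // vae_scale_factor, width // vae_scale_factor))

# Duplicate to match the requested batch size (must divide evenly).
if batch_size % mask_latent.shape[0] != 0:
    raise ValueError("batch_size must be a multiple of the number of masks")
mask_latent = mask_latent.repeat(batch_size // mask_latent.shape[0], 1, 1, 1)

print(mask_latent.shape)  # torch.Size([4, 1, 128, 128])
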
    @torch.no_grad()
    def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
        block_state.device = components._execution_device

        if block_state.height is None:
            block_state.height = components.default_height
        if block_state.width is None:
            block_state.width = components.default_width

        if block_state.padding_mask_crop is not None:
            block_state.crops_coords = components.mask_processor.get_crop_region(
                block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop
            )
            block_state.resize_mode = "fill"
        else:
            block_state.crops_coords = None
            block_state.resize_mode = "default"

        image = components.image_processor.preprocess(
            block_state.image,
            height=block_state.height,
            width=block_state.width,
            crops_coords=block_state.crops_coords,
            resize_mode=block_state.resize_mode,
        )
        image = image.to(dtype=torch.float32)

        mask = components.mask_processor.preprocess(
            block_state.mask_image,
            height=block_state.height,
            width=block_state.width,
            resize_mode=block_state.resize_mode,
            crops_coords=block_state.crops_coords,
        )
        block_state.masked_image = image * (mask < 0.5)

        block_state.batch_size = image.shape[0]
        image = image.to(device=block_state.device, dtype=block_state.dtype)
        block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator)

        # Prepare mask latent variables
        block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
            components,
            mask,
            block_state.masked_image,
            block_state.batch_size,
            block_state.height,
            block_state.width,
            block_state.dtype,
            block_state.device,
            block_state.generator,
        )

        self.set_block_state(state, block_state)

        return components, state