diffusers 0.33.1__py3-none-any.whl → 0.34.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +48 -1
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/diffusers_cli.py +1 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/hooks/faster_cache.py +2 -2
- diffusers/hooks/group_offloading.py +128 -29
- diffusers/hooks/hooks.py +2 -2
- diffusers/hooks/layerwise_casting.py +3 -3
- diffusers/hooks/pyramid_attention_broadcast.py +1 -1
- diffusers/image_processor.py +7 -2
- diffusers/loaders/__init__.py +4 -0
- diffusers/loaders/ip_adapter.py +5 -14
- diffusers/loaders/lora_base.py +212 -111
- diffusers/loaders/lora_conversion_utils.py +275 -34
- diffusers/loaders/lora_pipeline.py +1554 -819
- diffusers/loaders/peft.py +52 -109
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_model.py +20 -4
- diffusers/loaders/single_file_utils.py +225 -5
- diffusers/loaders/textual_inversion.py +3 -2
- diffusers/loaders/transformer_flux.py +1 -1
- diffusers/loaders/transformer_sd3.py +2 -2
- diffusers/loaders/unet.py +2 -16
- diffusers/loaders/unet_loader_utils.py +1 -1
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +15 -1
- diffusers/models/activations.py +5 -5
- diffusers/models/adapter.py +2 -3
- diffusers/models/attention.py +4 -4
- diffusers/models/attention_flax.py +10 -10
- diffusers/models/attention_processor.py +14 -10
- diffusers/models/auto_model.py +47 -10
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_dc.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1108 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_wan.py +256 -22
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +13 -2
- diffusers/models/autoencoders/vq_model.py +2 -2
- diffusers/models/cache_utils.py +1 -1
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flux.py +1 -1
- diffusers/models/controlnet_sd3.py +1 -1
- diffusers/models/controlnet_sparsectrl.py +1 -1
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -3
- diffusers/models/controlnets/controlnet_flax.py +1 -1
- diffusers/models/controlnets/controlnet_flux.py +16 -15
- diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
- diffusers/models/controlnets/controlnet_sana.py +290 -0
- diffusers/models/controlnets/controlnet_sd3.py +1 -1
- diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
- diffusers/models/controlnets/controlnet_union.py +1 -1
- diffusers/models/controlnets/controlnet_xs.py +7 -7
- diffusers/models/controlnets/multicontrolnet.py +4 -5
- diffusers/models/controlnets/multicontrolnet_union.py +5 -6
- diffusers/models/downsampling.py +2 -2
- diffusers/models/embeddings.py +10 -12
- diffusers/models/embeddings_flax.py +2 -2
- diffusers/models/lora.py +3 -3
- diffusers/models/modeling_utils.py +44 -14
- diffusers/models/normalization.py +4 -4
- diffusers/models/resnet.py +2 -2
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
- diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
- diffusers/models/transformers/consisid_transformer_3d.py +1 -1
- diffusers/models/transformers/dit_transformer_2d.py +2 -2
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
- diffusers/models/transformers/latte_transformer_3d.py +4 -5
- diffusers/models/transformers/lumina_nextdit2d.py +2 -2
- diffusers/models/transformers/pixart_transformer_2d.py +3 -3
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/sana_transformer.py +8 -3
- diffusers/models/transformers/stable_audio_transformer.py +5 -9
- diffusers/models/transformers/t5_film_transformer.py +3 -3
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +1 -1
- diffusers/models/transformers/transformer_chroma.py +742 -0
- diffusers/models/transformers/transformer_cogview3plus.py +5 -10
- diffusers/models/transformers/transformer_cogview4.py +317 -25
- diffusers/models/transformers/transformer_cosmos.py +579 -0
- diffusers/models/transformers/transformer_flux.py +9 -11
- diffusers/models/transformers/transformer_hidream_image.py +942 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -8
- diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- diffusers/models/transformers/transformer_ltx.py +2 -2
- diffusers/models/transformers/transformer_lumina2.py +1 -1
- diffusers/models/transformers/transformer_mochi.py +1 -1
- diffusers/models/transformers/transformer_omnigen.py +2 -2
- diffusers/models/transformers/transformer_sd3.py +7 -7
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/transformers/transformer_wan.py +24 -8
- diffusers/models/transformers/transformer_wan_vace.py +393 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +1 -1
- diffusers/models/unets/unet_2d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
- diffusers/models/unets/unet_2d_condition.py +2 -2
- diffusers/models/unets/unet_2d_condition_flax.py +2 -2
- diffusers/models/unets/unet_3d_blocks.py +1 -1
- diffusers/models/unets/unet_3d_condition.py +3 -3
- diffusers/models/unets/unet_i2vgen_xl.py +3 -3
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +2 -2
- diffusers/models/unets/unet_stable_cascade.py +1 -1
- diffusers/models/upsampling.py +2 -2
- diffusers/models/vae_flax.py +2 -2
- diffusers/models/vq_model.py +1 -1
- diffusers/pipelines/__init__.py +37 -6
- diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
- diffusers/pipelines/amused/pipeline_amused.py +7 -6
- diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
- diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
- diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +23 -13
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
- diffusers/pipelines/auto_pipeline.py +6 -7
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
- diffusers/pipelines/chroma/__init__.py +49 -0
- diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- diffusers/pipelines/chroma/pipeline_output.py +21 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +8 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +8 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +8 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +8 -8
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
- diffusers/pipelines/cogview4/pipeline_cogview4.py +7 -7
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
- diffusers/pipelines/consisid/consisid_utils.py +2 -2
- diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +8 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +5 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
- diffusers/pipelines/cosmos/__init__.py +54 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
- diffusers/pipelines/cosmos/pipeline_output.py +40 -0
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
- diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +1 -1
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
- diffusers/pipelines/flux/modeling_flux.py +1 -1
- diffusers/pipelines/flux/pipeline_flux.py +10 -17
- diffusers/pipelines/flux/pipeline_flux_control.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +30 -22
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +2 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_img2img.py +39 -6
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +11 -6
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/free_init_utils.py +2 -2
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hidream_image/__init__.py +47 -0
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
- diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
- diffusers/pipelines/hunyuan_video/__init__.py +2 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
- diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
- diffusers/pipelines/kolors/text_encoder.py +3 -3
- diffusers/pipelines/kolors/tokenizer.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
- diffusers/pipelines/latte/pipeline_latte.py +12 -12
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
- diffusers/pipelines/ltx/__init__.py +4 -0
- diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +51 -6
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +107 -29
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +50 -6
- diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
- diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
- diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
- diffusers/pipelines/mochi/pipeline_mochi.py +6 -6
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
- diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
- diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
- diffusers/pipelines/onnx_utils.py +15 -2
- diffusers/pipelines/pag/pag_utils.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
- diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
- diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
- diffusers/pipelines/pia/pipeline_pia.py +8 -6
- diffusers/pipelines/pipeline_flax_utils.py +3 -4
- diffusers/pipelines/pipeline_loading_utils.py +89 -13
- diffusers/pipelines/pipeline_utils.py +105 -33
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +11 -11
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +11 -11
- diffusers/pipelines/sana/__init__.py +4 -0
- diffusers/pipelines/sana/pipeline_sana.py +23 -21
- diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
- diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +3 -3
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
- diffusers/pipelines/stable_diffusion/__init__.py +0 -7
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +9 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +8 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
- diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
- diffusers/pipelines/unclip/text_proj.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
- diffusers/pipelines/visualcloze/__init__.py +52 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
- diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
- diffusers/pipelines/wan/__init__.py +2 -0
- diffusers/pipelines/wan/pipeline_wan.py +13 -10
- diffusers/pipelines/wan/pipeline_wan_i2v.py +38 -18
- diffusers/pipelines/wan/pipeline_wan_vace.py +976 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
- diffusers/quantizers/__init__.py +179 -1
- diffusers/quantizers/base.py +6 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
- diffusers/quantizers/bitsandbytes/utils.py +10 -7
- diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
- diffusers/quantizers/gguf/utils.py +16 -13
- diffusers/quantizers/quantization_config.py +18 -16
- diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +5 -1
- diffusers/schedulers/__init__.py +3 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -1
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
- diffusers/schedulers/scheduling_ddim.py +8 -8
- diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_ddim_flax.py +6 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
- diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
- diffusers/schedulers/scheduling_ddpm.py +9 -9
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
- diffusers/schedulers/scheduling_deis_multistep.py +8 -8
- diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -12
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -13
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
- diffusers/schedulers/scheduling_edm_euler.py +20 -11
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
- diffusers/schedulers/scheduling_heun_discrete.py +2 -2
- diffusers/schedulers/scheduling_ipndm.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
- diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
- diffusers/schedulers/scheduling_lcm.py +3 -3
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +4 -4
- diffusers/schedulers/scheduling_pndm_flax.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +9 -9
- diffusers/schedulers/scheduling_sasolver.py +15 -15
- diffusers/schedulers/scheduling_scm.py +1 -1
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
- diffusers/schedulers/scheduling_tcd.py +3 -3
- diffusers/schedulers/scheduling_unclip.py +5 -5
- diffusers/schedulers/scheduling_unipc_multistep.py +11 -11
- diffusers/schedulers/scheduling_utils.py +1 -1
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +13 -5
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +120 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +225 -0
- diffusers/utils/dynamic_modules_utils.py +21 -3
- diffusers/utils/export_utils.py +1 -1
- diffusers/utils/import_utils.py +81 -18
- diffusers/utils/logging.py +1 -1
- diffusers/utils/outputs.py +2 -1
- diffusers/utils/peft_utils.py +91 -8
- diffusers/utils/state_dict_utils.py +20 -3
- diffusers/utils/testing_utils.py +59 -7
- diffusers/utils/torch_utils.py +25 -5
- diffusers/video_processor.py +2 -2
- {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/METADATA +70 -55
- diffusers-0.34.0.dist-info/RECORD +639 -0
- {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/WHEEL +1 -1
- diffusers-0.33.1.dist-info/RECORD +0 -608
- {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/LICENSE +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.34.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,742 @@
|
|
1
|
+
# Copyright 2025 Black Forest Labs, The HuggingFace Team and loadstone-rock . All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
|
16
|
+
from typing import Any, Dict, Optional, Tuple, Union
|
17
|
+
|
18
|
+
import numpy as np
|
19
|
+
import torch
|
20
|
+
import torch.nn as nn
|
21
|
+
|
22
|
+
from ...configuration_utils import ConfigMixin, register_to_config
|
23
|
+
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
24
|
+
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
25
|
+
from ...utils.import_utils import is_torch_npu_available
|
26
|
+
from ...utils.torch_utils import maybe_allow_in_graph
|
27
|
+
from ..attention import FeedForward
|
28
|
+
from ..attention_processor import (
|
29
|
+
Attention,
|
30
|
+
AttentionProcessor,
|
31
|
+
FluxAttnProcessor2_0,
|
32
|
+
FluxAttnProcessor2_0_NPU,
|
33
|
+
FusedFluxAttnProcessor2_0,
|
34
|
+
)
|
35
|
+
from ..cache_utils import CacheMixin
|
36
|
+
from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding
|
37
|
+
from ..modeling_outputs import Transformer2DModelOutput
|
38
|
+
from ..modeling_utils import ModelMixin
|
39
|
+
from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm
|
40
|
+
|
41
|
+
|
42
|
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
43
|
+
|
44
|
+
|
45
|
+
class ChromaAdaLayerNormZeroPruned(nn.Module):
    r"""
    Adaptive layer norm zero (adaLN-Zero) with externally supplied modulation.

    This "pruned" variant owns no modulation projection. The six modulation vectors (shift/scale/gate for attention,
    shift/scale/gate for the MLP) arrive through `emb`. That tensor is flattened over dims 1-2 and split into six equal
    chunks along dim 1.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`, *optional*):
            The size of the embeddings dictionary. When given, a `CombinedTimestepLabelEmbeddings` module is created
            and used in `forward` to compute `emb` from `timestep` and `class_labels`.
        norm_type (`str`, defaults to `"layer_norm"`): Either `"layer_norm"` or `"fp32_layer_norm"`.
        bias (`bool`, defaults to `True`): Accepted for signature compatibility; not used by this layer.
    """

    def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
        super().__init__()
        # Optional embedder; without it, callers must pass the modulation tensor via `emb`.
        self.emb = None if num_embeddings is None else CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)

        if norm_type not in ("layer_norm", "fp32_layer_norm"):
            raise ValueError(
                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
            )
        self.norm = (
            nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
            if norm_type == "layer_norm"
            else FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
        )

    def forward(
        self,
        x: torch.Tensor,
        timestep: Optional[torch.Tensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        hidden_dtype: Optional[torch.dtype] = None,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Normalize `x` and apply the attention shift/scale taken from `emb`.

        Returns:
            The modulated `x`, followed by `gate_msa`, `shift_mlp`, `scale_mlp` and `gate_mlp` for the caller to apply.
        """
        if self.emb is not None:
            emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)

        chunks = emb.flatten(1, 2).chunk(6, dim=1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunks

        # Broadcast the per-sample modulation over dim 1 of `x`.
        normalized = self.norm(x)
        x = normalized * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
83
|
+
|
84
|
+
|
85
|
+
class ChromaAdaLayerNormZeroSinglePruned(nn.Module):
    r"""
    Adaptive layer norm zero (adaLN-Zero) for single-stream blocks, with externally supplied modulation.

    The three modulation vectors (shift, scale, gate) arrive through `emb` rather than being projected from a
    conditioning embedding inside the layer.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        norm_type (`str`, defaults to `"layer_norm"`): Normalization layer to use. Only `"layer_norm"` is supported.
        bias (`bool`, defaults to `True`): Accepted for signature compatibility; not used by this layer.
    """

    def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
        super().__init__()

        if norm_type == "layer_norm":
            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
        else:
            # Only "layer_norm" is implemented, so the message must not advertise any other option.
            raise ValueError(
                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'."
            )

    def forward(
        self,
        x: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Normalize `x` and apply the shift/scale taken from `emb`.

        Args:
            x (`torch.Tensor`): Hidden states; the modulation is broadcast over dim 1.
            emb (`torch.Tensor`):
                Modulation tensor, flattened over dims 1-2 and split into three equal chunks along dim 1
                (shift, scale, gate). It is required despite the `None` default.

        Returns:
            A 2-tuple `(modulated_x, gate)`; applying the gate is left to the caller.
        """
        shift_msa, scale_msa, gate_msa = emb.flatten(1, 2).chunk(3, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa
|
112
|
+
|
113
|
+
|
114
|
+
class ChromaAdaLayerNormContinuousPruned(nn.Module):
|
115
|
+
r"""
|
116
|
+
Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
|
117
|
+
|
118
|
+
Args:
|
119
|
+
embedding_dim (`int`): Embedding dimension to use during projection.
|
120
|
+
conditioning_embedding_dim (`int`): Dimension of the input condition.
|
121
|
+
elementwise_affine (`bool`, defaults to `True`):
|
122
|
+
Boolean flag to denote if affine transformation should be applied.
|
123
|
+
eps (`float`, defaults to 1e-5): Epsilon factor.
|
124
|
+
bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use.
|
125
|
+
norm_type (`str`, defaults to `"layer_norm"`):
|
126
|
+
Normalization layer to use. Values supported: "layer_norm", "rms_norm".
|
127
|
+
"""
|
128
|
+
|
129
|
+
def __init__(
|
130
|
+
self,
|
131
|
+
embedding_dim: int,
|
132
|
+
conditioning_embedding_dim: int,
|
133
|
+
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
|
134
|
+
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
|
135
|
+
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
|
136
|
+
# However, this is how it was implemented in the original code, and it's rather likely you should
|
137
|
+
# set `elementwise_affine` to False.
|
138
|
+
elementwise_affine=True,
|
139
|
+
eps=1e-5,
|
140
|
+
bias=True,
|
141
|
+
norm_type="layer_norm",
|
142
|
+
):
|
143
|
+
super().__init__()
|
144
|
+
if norm_type == "layer_norm":
|
145
|
+
self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
|
146
|
+
elif norm_type == "rms_norm":
|
147
|
+
self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
|
148
|
+
else:
|
149
|
+
raise ValueError(f"unknown norm_type {norm_type}")
|
150
|
+
|
151
|
+
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
|
152
|
+
# convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
|
153
|
+
shift, scale = torch.chunk(emb.flatten(1, 2).to(x.dtype), 2, dim=1)
|
154
|
+
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
155
|
+
return x
|
156
|
+
|
157
|
+
|
158
|
+
class ChromaCombinedTimestepTextProjEmbeddings(nn.Module):
    r"""
    Builds the input sequence for Chroma's distilled guidance approximator.

    Each of the `out_dim` output rows concatenates three parts:
    - a sinusoidal timestep projection;
    - a sinusoidal projection of a fixed guidance value of 0;
    - a fixed sinusoidal embedding of the row index (the non-persistent `mod_proj` buffer).

    Args:
        num_channels (`int`): Channel count of each timestep/guidance sinusoidal projection.
        out_dim (`int`): Number of rows produced per sample.
    """

    def __init__(self, num_channels: int, out_dim: int):
        super().__init__()

        self.time_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.guidance_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)

        # Fixed per-row index embedding; not saved in the state dict (persistent=False).
        self.register_buffer(
            "mod_proj",
            get_timestep_embedding(
                torch.arange(out_dim) * 1000, 2 * num_channels, flip_sin_to_cos=True, downscale_freq_shift=0
            ),
            persistent=False,
        )

    def forward(self, timestep: torch.Tensor) -> torch.Tensor:
        """
        Args:
            timestep (`torch.Tensor`): Per-sample timesteps; dim 0 is the batch dimension (presumably 1-D).

        Returns:
            `torch.Tensor` of shape `(batch_size, out_dim, ...)` in `timestep.dtype`. The last dim is the concatenated
            timestep, guidance and row-index features.
        """
        mod_index_length = self.mod_proj.shape[0]
        batch_size = timestep.shape[0]

        timesteps_proj = self.time_proj(timestep).to(dtype=timestep.dtype)
        # The guidance input is always 0. Create the zeros on the timestep's device instead of allocating a CPU
        # tensor on every call and copying it over; the sinusoidal embedding of 0 is the same either way.
        guidance = torch.zeros(batch_size, device=timestep.device)
        guidance_proj = self.guidance_proj(guidance).to(dtype=timestep.dtype, device=timestep.device)

        mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device).repeat(batch_size, 1, 1)
        # Every row shares the same timestep/guidance features; only the `mod_proj` part differs per row.
        timestep_guidance = (
            torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
        )
        input_vec = torch.cat([timestep_guidance, mod_proj], dim=-1)
        return input_vec.to(timestep.dtype)
|
188
|
+
|
189
|
+
|
190
|
+
class ChromaApproximator(nn.Module):
    r"""
    Residual MLP (the "distilled guidance layer") that maps the approximator input to modulation vectors.

    The input is projected to `hidden_dim` and passed through `n_layers` pre-norm residual blocks
    (`x + layer(norm(x))`). It is then projected to `out_dim`.

    Args:
        in_dim (`int`): Feature size of the input.
        out_dim (`int`): Feature size of the output.
        hidden_dim (`int`): Width of the residual stream.
        n_layers (`int`, defaults to 5): Number of residual blocks.
    """

    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers: int = 5):
        super().__init__()
        self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
        self.layers = nn.ModuleList(
            PixArtAlphaTextProjection(hidden_dim, hidden_dim, act_fn="silu") for _ in range(n_layers)
        )
        self.norms = nn.ModuleList(nn.RMSNorm(hidden_dim) for _ in range(n_layers))
        self.out_proj = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        hidden = self.in_proj(x)
        for block, norm in zip(self.layers, self.norms):
            residual_update = block(norm(hidden))
            hidden = hidden + residual_update
        return self.out_proj(hidden)
|
207
|
+
|
208
|
+
|
209
|
+
@maybe_allow_in_graph
class ChromaSingleTransformerBlock(nn.Module):
    r"""
    Single-stream Chroma transformer block.

    Parallel attention and MLP branches read the same normalized input. Their outputs are concatenated, projected back
    to `dim`, gated, and added to the residual. Modulation comes precomputed through `temb`.

    Args:
        dim (`int`): Hidden size of the block.
        num_attention_heads (`int`): Number of attention heads.
        attention_head_dim (`int`): Size of each attention head.
        mlp_ratio (`float`, defaults to 4.0): Ratio of the MLP hidden size to `dim`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_ratio: float = 4.0,
    ):
        super().__init__()
        self.mlp_hidden_dim = int(dim * mlp_ratio)
        self.norm = ChromaAdaLayerNormZeroSinglePruned(dim)
        self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
        self.act_mlp = nn.GELU(approximate="tanh")
        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

        # NPU devices are no longer special-cased. The implicit FluxAttnProcessor2_0_NPU default was deprecated for
        # removal in 0.34.0, and `deprecate` raises once that version is reached, which made this block unbuildable
        # on NPU. Use the same default as `ChromaTransformerBlock`. NPU users can opt in explicitly via
        # `set_attn_processor(FluxAttnProcessor2_0_NPU())`.
        processor = FluxAttnProcessor2_0()

        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            processor=processor,
            qk_norm="rms_norm",
            eps=1e-6,
            pre_only=True,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor`): Input tokens; the block returns a tensor of the same shape.
            temb (`torch.Tensor`): Modulation tensor split by the norm into (shift, scale, gate).
            image_rotary_emb (`Tuple[torch.Tensor, torch.Tensor]`, *optional*): Rotary embedding for attention.
            attention_mask (`torch.Tensor`, *optional*):
                Per-token mask (indexed as 2-D `(batch, seq)`), expanded here to a pairwise `(batch, 1, seq, seq)` mask.
            joint_attention_kwargs (`dict`, *optional*): Extra keyword arguments forwarded to the attention call.
        """
        residual = hidden_states
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
        joint_attention_kwargs = joint_attention_kwargs or {}

        if attention_mask is not None:
            # Outer product of the mask with itself: for 0/1 or boolean masks, (i, j) is kept only if both are kept.
            attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]

        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
            attention_mask=attention_mask,
            **joint_attention_kwargs,
        )

        # Fuse the attention and MLP branches with a single output projection, then apply the gated residual.
        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        gate = gate.unsqueeze(1)
        hidden_states = gate * self.proj_out(hidden_states)
        hidden_states = residual + hidden_states
        if hidden_states.dtype == torch.float16:
            # Keep values inside the finite float16 range.
            hidden_states = hidden_states.clip(-65504, 65504)

        return hidden_states
|
279
|
+
|
280
|
+
|
281
|
+
@maybe_allow_in_graph
class ChromaTransformerBlock(nn.Module):
    r"""
    Dual-stream Chroma transformer block.

    Image tokens (`hidden_states`) and text tokens (`encoder_hidden_states`) have separate norms and feed-forward
    networks but go through one shared `Attention` call. All modulation vectors arrive precomputed through `temb`:
    the first 6 rows modulate the image stream and the remaining rows the text stream.

    Args:
        dim (`int`): Hidden size of both streams.
        num_attention_heads (`int`): Number of attention heads.
        attention_head_dim (`int`): Size of each attention head.
        qk_norm (`str`, defaults to `"rms_norm"`): Query/key normalization passed to `Attention`.
        eps (`float`, defaults to 1e-6): Epsilon passed to `Attention`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        qk_norm: str = "rms_norm",
        eps: float = 1e-6,
    ):
        super().__init__()
        # adaLN-Zero norms for the image and text streams.
        self.norm1 = ChromaAdaLayerNormZeroPruned(dim)
        self.norm1_context = ChromaAdaLayerNormZeroPruned(dim)

        # Joint attention shared by both streams.
        self.attn = Attention(
            query_dim=dim,
            out_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            context_pre_only=False,
            bias=True,
            qk_norm=qk_norm,
            eps=eps,
            processor=FluxAttnProcessor2_0(),
        )

        # Image-stream feed-forward.
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        # Text-stream feed-forward.
        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run one dual-stream block.

        Returns:
            `(encoder_hidden_states, hidden_states)`: the updated text stream first, then the image stream.
        """
        img_mod = temb[:, :6]
        txt_mod = temb[:, 6:]
        norm_img, img_gate_msa, img_shift_mlp, img_scale_mlp, img_gate_mlp = self.norm1(hidden_states, emb=img_mod)
        norm_txt, txt_gate_msa, txt_shift_mlp, txt_scale_mlp, txt_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=txt_mod
        )

        attn_kwargs = joint_attention_kwargs if joint_attention_kwargs else {}
        if attention_mask is not None:
            # Pairwise mask (batch, 1, seq, seq) from the per-token mask.
            attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]

        outputs = self.attn(
            hidden_states=norm_img,
            encoder_hidden_states=norm_txt,
            image_rotary_emb=image_rotary_emb,
            attention_mask=attention_mask,
            **attn_kwargs,
        )

        # The processor returns (image, text) or, with IP-Adapter inputs, (image, text, ip).
        n_outputs = len(outputs)
        if n_outputs == 3:
            img_attn, txt_attn, ip_attn = outputs
        elif n_outputs == 2:
            img_attn, txt_attn = outputs

        # Image stream: gated attention residual, then modulated and gated feed-forward residual.
        hidden_states = hidden_states + img_gate_msa.unsqueeze(1) * img_attn
        img_ff_in = self.norm2(hidden_states) * (1 + img_scale_mlp[:, None]) + img_shift_mlp[:, None]
        hidden_states = hidden_states + img_gate_mlp.unsqueeze(1) * self.ff(img_ff_in)
        if n_outputs == 3:
            hidden_states = hidden_states + ip_attn

        # Text stream: same structure with its own modulation and feed-forward.
        encoder_hidden_states = encoder_hidden_states + txt_gate_msa.unsqueeze(1) * txt_attn
        txt_ff_in = self.norm2_context(encoder_hidden_states) * (1 + txt_scale_mlp[:, None]) + txt_shift_mlp[:, None]
        encoder_hidden_states = encoder_hidden_states + txt_gate_mlp.unsqueeze(1) * self.ff_context(txt_ff_in)

        if encoder_hidden_states.dtype == torch.float16:
            # Keep values inside the finite float16 range.
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states
|
376
|
+
|
377
|
+
|
378
|
+
class ChromaTransformer2DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
):
    """
    The Transformer model introduced in Flux, modified for Chroma.

    Reference: https://huggingface.co/lodestones/Chroma

    Args:
        patch_size (`int`, defaults to `1`):
            Patch size to turn the input data into small patches.
        in_channels (`int`, defaults to `64`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of channels in the output. If not specified, it defaults to `in_channels`.
        num_layers (`int`, defaults to `19`):
            The number of layers of dual stream DiT blocks to use.
        num_single_layers (`int`, defaults to `38`):
            The number of layers of single stream DiT blocks to use.
        attention_head_dim (`int`, defaults to `128`):
            The number of dimensions to use for each attention head.
        num_attention_heads (`int`, defaults to `24`):
            The number of attention heads to use.
        joint_attention_dim (`int`, defaults to `4096`):
            The number of dimensions to use for the joint attention (embedding/channel dimension of
            `encoder_hidden_states`).
        axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
            The dimensions to use for the rotary positional embeddings.
        approximator_num_channels (`int`, defaults to `64`):
            Input feature size of the distilled guidance approximator. A quarter of it is used as the channel count
            of each sinusoidal projection in `time_text_embed`.
        approximator_hidden_dim (`int`, defaults to `5120`):
            Hidden size of the distilled guidance approximator.
        approximator_layers (`int`, defaults to `5`):
            Number of residual layers in the distilled guidance approximator.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        out_channels: Optional[int] = None,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
        approximator_num_channels: int = 64,
        approximator_hidden_dim: int = 5120,
        approximator_layers: int = 5,
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)

        # One modulation row per vector: 3 per single-stream block, 6 (image) + 6 (text) per dual-stream block, and 2
        # (shift, scale) for the final norm. `forward` slices `pooled_temb` using this same layout.
        self.time_text_embed = ChromaCombinedTimestepTextProjEmbeddings(
            num_channels=approximator_num_channels // 4,
            out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2,
        )
        self.distilled_guidance_layer = ChromaApproximator(
            in_dim=approximator_num_channels,
            out_dim=self.inner_dim,
            hidden_dim=approximator_hidden_dim,
            n_layers=approximator_layers,
        )

        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
        self.x_embedder = nn.Linear(in_channels, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                ChromaTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for _ in range(num_layers)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                ChromaSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for _ in range(num_single_layers)
            ]
        )

        self.norm_out = ChromaAdaLayerNormContinuousPruned(
            self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
        )
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedFluxAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
        return_dict: bool = True,
        controlnet_blocks_repeat: bool = False,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """
        The [`ChromaTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step. It is cast to the hidden-states dtype and multiplied by 1000 before
                being embedded.
            img_ids (`torch.Tensor`):
                Positional ids of the image tokens for the rotary embedding. A 3D tensor is still accepted but is
                deprecated; only its first batch element is used.
            txt_ids (`torch.Tensor`):
                Positional ids of the text tokens for the rotary embedding. Same 3D deprecation as `img_ids`.
            attention_mask (`torch.Tensor`, *optional*):
                Per-token mask passed to every block. Each block indexes it as `(batch_size, sequence_length)` and
                expands it to a pairwise mask.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
                The dictionary is copied, never modified in place. Two keys are handled here:
                - a `"scale"` entry is removed and used as the LoRA scale;
                - an `"ip_adapter_image_embeds"` entry is projected with `encoder_hid_proj` and passed on as
                  `"ip_hidden_states"`.
            controlnet_block_samples (`list` of `torch.Tensor`, *optional*):
                Residuals added to the image hidden states after the dual-stream blocks.
            controlnet_single_block_samples (`list` of `torch.Tensor`, *optional*):
                Residuals added to the image part of the hidden states after the single-stream blocks.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.
            controlnet_blocks_repeat (`bool`, defaults to `False`):
                If `True`, `controlnet_block_samples` are cycled over the dual-stream blocks (XLabs ControlNet) instead
                of being spread evenly across them.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # Record whether a LoRA scale was requested *before* "scale" is popped below, so the non-PEFT warning can
        # actually fire. Previously it checked the dict after the pop and was dead code.
        scale_requested = joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        elif scale_requested:
            logger.warning(
                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
            )

        hidden_states = self.x_embedder(hidden_states)

        timestep = timestep.to(hidden_states.dtype) * 1000

        input_vec = self.time_text_embed(timestep)
        pooled_temb = self.distilled_guidance_layer(input_vec)

        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if txt_ids.ndim == 3:
            logger.warning(
                "Passing `txt_ids` 3d torch.Tensor is deprecated. "
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            logger.warning(
                "Passing `img_ids` 3d torch.Tensor is deprecated. "
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            img_ids = img_ids[0]

        ids = torch.cat((txt_ids, img_ids), dim=0)
        image_rotary_emb = self.pos_embed(ids)

        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
            ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
            joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})

        # Row layout of `pooled_temb` (matches `out_dim` in `__init__`), in order:
        # single-stream rows, dual-stream image rows, dual-stream text rows, final-norm rows.
        # These offsets do not depend on the block index, so compute them once.
        img_offset = 3 * len(self.single_transformer_blocks)
        txt_offset = img_offset + 6 * len(self.transformer_blocks)

        for index_block, block in enumerate(self.transformer_blocks):
            img_modulation = img_offset + 6 * index_block
            text_modulation = txt_offset + 6 * index_block
            temb = torch.cat(
                (
                    pooled_temb[:, img_modulation : img_modulation + 6],
                    pooled_temb[:, text_modulation : text_modulation + 6],
                ),
                dim=1,
            )
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Pass every input the regular path passes, so checkpointing does not change the computation.
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    attention_mask,
                    joint_attention_kwargs,
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    attention_mask=attention_mask,
                    joint_attention_kwargs=joint_attention_kwargs,
                )

            # controlnet residual
            if controlnet_block_samples is not None:
                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
                interval_control = int(np.ceil(interval_control))
                # For Xlabs ControlNet.
                if controlnet_blocks_repeat:
                    hidden_states = (
                        hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
                    )
                else:
                    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        for index_block, block in enumerate(self.single_transformer_blocks):
            start_idx = 3 * index_block
            temb = pooled_temb[:, start_idx : start_idx + 3]
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Previously `attention_mask` (and the attention kwargs) were dropped here, so masked inputs behaved
                # differently with gradient checkpointing enabled.
                hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    temb,
                    image_rotary_emb,
                    attention_mask,
                    joint_attention_kwargs,
                )

            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    attention_mask=attention_mask,
                    joint_attention_kwargs=joint_attention_kwargs,
                )

            # controlnet residual
            if controlnet_single_block_samples is not None:
                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
                interval_control = int(np.ceil(interval_control))
                hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                    hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                    + controlnet_single_block_samples[index_block // interval_control]
                )

        # Drop the text tokens; only the image tokens are projected to the output.
        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

        temb = pooled_temb[:, -2:]
        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
|