diffusers 0.32.1__py3-none-any.whl → 0.33.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +198 -28
- diffusers/loaders/lora_conversion_utils.py +679 -44
- diffusers/loaders/lora_pipeline.py +1963 -801
- diffusers/loaders/peft.py +169 -84
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +653 -75
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +22 -32
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +593 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +10 -2
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +14 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
- diffusers-0.33.0.dist-info/RECORD +608 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
- diffusers-0.32.1.dist-info/RECORD +0 -550
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
diffusers/models/controlnets/controlnet_sd3.py

```diff
@@ -21,7 +21,7 @@ import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND,
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import JointTransformerBlock
 from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
@@ -40,6 +40,48 @@ class SD3ControlNetOutput(BaseOutput):
 
 
 class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+    r"""
+    ControlNet model for [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
+
+    Parameters:
+        sample_size (`int`, defaults to `128`):
+            The width/height of the latents. This is fixed during training since it is used to learn a number of
+            position embeddings.
+        patch_size (`int`, defaults to `2`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `16`):
+            The number of latent channels in the input.
+        num_layers (`int`, defaults to `18`):
+            The number of layers of transformer blocks to use.
+        attention_head_dim (`int`, defaults to `64`):
+            The number of channels in each head.
+        num_attention_heads (`int`, defaults to `18`):
+            The number of heads to use for multi-head attention.
+        joint_attention_dim (`int`, defaults to `4096`):
+            The embedding dimension to use for joint text-image attention.
+        caption_projection_dim (`int`, defaults to `1152`):
+            The embedding dimension of caption embeddings.
+        pooled_projection_dim (`int`, defaults to `2048`):
+            The embedding dimension of pooled text projections.
+        out_channels (`int`, defaults to `16`):
+            The number of latent channels in the output.
+        pos_embed_max_size (`int`, defaults to `96`):
+            The maximum latent height/width of positional embeddings.
+        extra_conditioning_channels (`int`, defaults to `0`):
+            The number of extra channels to use for conditioning for patch embedding.
+        dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
+            The number of dual-stream transformer blocks to use.
+        qk_norm (`str`, *optional*, defaults to `None`):
+            The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
+        pos_embed_type (`str`, defaults to `"sincos"`):
+            The type of positional embedding to use. Choose between `"sincos"` and `None`.
+        use_pos_embed (`bool`, defaults to `True`):
+            Whether to use positional embeddings.
+        force_zeros_for_pooled_projection (`bool`, defaults to `True`):
+            Whether to force zeros for pooled projection embeddings. This is handled in the pipelines by reading the
+            config value of the ControlNet model.
+    """
+
     _supports_gradient_checkpointing = True
 
     @register_to_config
@@ -93,7 +135,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
                 JointTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
-                    attention_head_dim=
+                    attention_head_dim=attention_head_dim,
                     context_pre_only=False,
                     qk_norm=qk_norm,
                     use_dual_attention=True if i in dual_attention_layers else False,
@@ -108,7 +150,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
                 SD3SingleTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
-                    attention_head_dim=
+                    attention_head_dim=attention_head_dim,
                 )
                 for _ in range(num_layers)
             ]
@@ -262,10 +304,6 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
         if self.original_attn_processors is not None:
             self.set_attn_processor(self.original_attn_processors)
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     # Notes: This is for SD3.5 8b controlnet, which shares the pos_embed with the transformer
     # we should have handled this in conversion script
     def _get_pos_embed_from_transformer(self, transformer):
@@ -301,28 +339,28 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
 
     def forward(
         self,
-        hidden_states: torch.
+        hidden_states: torch.Tensor,
         controlnet_cond: torch.Tensor,
         conditioning_scale: float = 1.0,
-        encoder_hidden_states: torch.
-        pooled_projections: torch.
+        encoder_hidden_states: torch.Tensor = None,
+        pooled_projections: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ) -> Union[torch.
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
         The [`SD3Transformer2DModel`] forward method.
 
         Args:
-            hidden_states (`torch.
+            hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
                 Input `hidden_states`.
             controlnet_cond (`torch.Tensor`):
                 The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
             conditioning_scale (`float`, defaults to `1.0`):
                 The scale factor for ControlNet outputs.
-            encoder_hidden_states (`torch.
+            encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.
+            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                 from the embeddings of input conditions.
             timestep ( `torch.LongTensor`):
                 Used to indicate denoising step.
@@ -382,30 +420,16 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
 
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 if self.context_embedder is not None:
-                    encoder_hidden_states, hidden_states =
-
+                    encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                        block,
                         hidden_states,
                         encoder_hidden_states,
                         temb,
-                        **ckpt_kwargs,
                     )
                 else:
                     # SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states`
-                    hidden_states =
-                        create_custom_forward(block), hidden_states, temb, **ckpt_kwargs
-                    )
+                    hidden_states = self._gradient_checkpointing_func(block, hidden_states, temb)
 
             else:
                 if self.context_embedder is not None:
@@ -455,11 +479,11 @@ class SD3MultiControlNetModel(ModelMixin):
 
     def forward(
         self,
-        hidden_states: torch.
+        hidden_states: torch.Tensor,
         controlnet_cond: List[torch.tensor],
         conditioning_scale: List[float],
-        pooled_projections: torch.
-        encoder_hidden_states: torch.
+        pooled_projections: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
```
diffusers/models/controlnets/controlnet_sparsectrl.py

```diff
@@ -590,10 +590,6 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)
 
-    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
-        if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, UNetMidBlock2DCrossAttn)):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         sample: torch.Tensor,
@@ -671,10 +667,11 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
@@ -690,7 +687,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         t_emb = t_emb.to(dtype=sample.dtype)
 
         emb = self.time_embedding(t_emb, timestep_cond)
-        emb = emb.repeat_interleave(sample_num_frames, dim=0)
+        emb = emb.repeat_interleave(sample_num_frames, dim=0, output_size=emb.shape[0] * sample_num_frames)
 
         # 2. pre-process
         batch_size, channels, num_frames, height, width = sample.shape
```
diffusers/models/controlnets/controlnet_union.py

```diff
@@ -29,8 +29,6 @@ from ..attention_processor import (
 from ..embeddings import TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
 from ..modeling_utils import ModelMixin
 from ..unets.unet_2d_blocks import (
-    CrossAttnDownBlock2D,
-    DownBlock2D,
     UNetMidBlock2DCrossAttn,
     get_down_block,
 )
@@ -599,10 +597,6 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)
 
-    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
-        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         sample: torch.Tensor,
@@ -611,12 +605,13 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         controlnet_cond: List[torch.Tensor],
         control_type: torch.Tensor,
         control_type_idx: List[int],
-        conditioning_scale: float = 1.0,
+        conditioning_scale: Union[float, List[float]] = 1.0,
         class_labels: Optional[torch.Tensor] = None,
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        from_multi: bool = False,
         guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
@@ -653,6 +648,8 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 Additional conditions for the Stable Diffusion XL UNet.
             cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
                 A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
+            from_multi (`bool`, defaults to `False`):
+                Use standard scaling when called from `MultiControlNetUnionModel`.
             guess_mode (`bool`, defaults to `False`):
                 In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
                 you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
@@ -664,6 +661,9 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
             returned where the first element is the sample tensor.
         """
+        if isinstance(conditioning_scale, float):
+            conditioning_scale = [conditioning_scale] * len(controlnet_cond)
+
         # check channel order
         channel_order = self.config.controlnet_conditioning_channel_order
 
@@ -681,10 +681,11 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
@@ -747,12 +748,16 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         inputs = []
         condition_list = []
 
-        for cond, control_idx in zip(controlnet_cond, control_type_idx):
+        for cond, control_idx, scale in zip(controlnet_cond, control_type_idx, conditioning_scale):
             condition = self.controlnet_cond_embedding(cond)
             feat_seq = torch.mean(condition, dim=(2, 3))
             feat_seq = feat_seq + self.task_embedding[control_idx]
-
-
+            if from_multi:
+                inputs.append(feat_seq.unsqueeze(1))
+                condition_list.append(condition)
+            else:
+                inputs.append(feat_seq.unsqueeze(1) * scale)
+                condition_list.append(condition * scale)
 
         condition = sample
         feat_seq = torch.mean(condition, dim=(2, 3))
@@ -764,10 +769,13 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             x = layer(x)
 
         controlnet_cond_fuser = sample * 0.0
-        for idx, condition in enumerate(condition_list[:-1]):
+        for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
             alpha = self.spatial_ch_projs(x[:, idx])
             alpha = alpha.unsqueeze(-1).unsqueeze(-1)
-
+            if from_multi:
+                controlnet_cond_fuser += condition + alpha
+            else:
+                controlnet_cond_fuser += condition + alpha * scale
 
         sample = sample + controlnet_cond_fuser
 
@@ -811,12 +819,13 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         # 6. scaling
         if guess_mode and not self.config.global_pool_conditions:
             scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
-
+            if from_multi:
+                scales = scales * conditioning_scale[0]
             down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
             mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
-
-            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
-            mid_block_res_sample = mid_block_res_sample * conditioning_scale
+        elif from_multi:
+            down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
+            mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]
 
         if self.config.global_pool_conditions:
             down_block_res_samples = [
```
diffusers/models/controlnets/controlnet_xs.py

```diff
@@ -20,7 +20,7 @@ import torch.utils.checkpoint
 from torch import Tensor, nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import BaseOutput,
+from ...utils import BaseOutput, logging
 from ...utils.torch_utils import apply_freeu
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -864,10 +864,6 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
         for u in self.up_blocks:
             u.freeze_base_params()
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     @property
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -1088,10 +1084,11 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
@@ -1449,15 +1446,6 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
         base_blocks = list(zip(self.base_resnets, self.base_attentions))
         ctrl_blocks = list(zip(self.ctrl_resnets, self.ctrl_attentions))
 
-        def create_custom_forward(module, return_dict=None):
-            def custom_forward(*inputs):
-                if return_dict is not None:
-                    return module(*inputs, return_dict=return_dict)
-                else:
-                    return module(*inputs)
-
-            return custom_forward
-
         for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip(
             base_blocks, ctrl_blocks, self.base_to_ctrl, self.ctrl_to_base
         ):
@@ -1467,13 +1455,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
 
             # apply base subblock
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                h_base = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(b_res),
-                    h_base,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                h_base = self._gradient_checkpointing_func(b_res, h_base, temb)
             else:
                 h_base = b_res(h_base, temb)
 
@@ -1490,13 +1472,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
             # apply ctrl subblock
             if apply_control:
                 if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                    h_ctrl = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(c_res),
-                        h_ctrl,
-                        temb,
-                        **ckpt_kwargs,
-                    )
+                    h_ctrl = self._gradient_checkpointing_func(c_res, h_ctrl, temb)
                 else:
                     h_ctrl = c_res(h_ctrl, temb)
                 if c_attn is not None:
@@ -1861,15 +1837,6 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
             and getattr(self, "b2", None)
         )
 
-        def create_custom_forward(module, return_dict=None):
-            def custom_forward(*inputs):
-                if return_dict is not None:
-                    return module(*inputs, return_dict=return_dict)
-                else:
-                    return module(*inputs)
-
-            return custom_forward
-
         def maybe_apply_freeu_to_subblock(hidden_states, res_h_base):
             # FreeU: Only operate on the first two stages
             if is_freeu_enabled:
@@ -1899,13 +1866,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
             hidden_states = torch.cat([hidden_states, res_h_base], dim=1)
 
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = resnet(hidden_states, temb)
 
```
diffusers/models/controlnets/multicontrolnet_union.py (new file)

```diff
@@ -0,0 +1,196 @@
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from ...models.controlnets.controlnet import ControlNetOutput
+from ...models.controlnets.controlnet_union import ControlNetUnionModel
+from ...models.modeling_utils import ModelMixin
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class MultiControlNetUnionModel(ModelMixin):
+    r"""
+    Multiple `ControlNetUnionModel` wrapper class for Multi-ControlNet-Union.
+
+    This module is a wrapper for multiple instances of the `ControlNetUnionModel`. The `forward()` API is designed to
+    be compatible with `ControlNetUnionModel`.
+
+    Args:
+        controlnets (`List[ControlNetUnionModel]`):
+            Provides additional conditioning to the unet during the denoising process. You must set multiple
+            `ControlNetUnionModel` as a list.
+    """
+
+    def __init__(self, controlnets: Union[List[ControlNetUnionModel], Tuple[ControlNetUnionModel]]):
+        super().__init__()
+        self.nets = nn.ModuleList(controlnets)
+
+    def forward(
+        self,
+        sample: torch.Tensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        controlnet_cond: List[torch.tensor],
+        control_type: List[torch.Tensor],
+        control_type_idx: List[List[int]],
+        conditioning_scale: List[float],
+        class_labels: Optional[torch.Tensor] = None,
+        timestep_cond: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guess_mode: bool = False,
+        return_dict: bool = True,
+    ) -> Union[ControlNetOutput, Tuple]:
+        down_block_res_samples, mid_block_res_sample = None, None
+        for i, (image, ctype, ctype_idx, scale, controlnet) in enumerate(
+            zip(controlnet_cond, control_type, control_type_idx, conditioning_scale, self.nets)
+        ):
+            if scale == 0.0:
+                continue
+            down_samples, mid_sample = controlnet(
+                sample=sample,
+                timestep=timestep,
+                encoder_hidden_states=encoder_hidden_states,
+                controlnet_cond=image,
+                control_type=ctype,
+                control_type_idx=ctype_idx,
+                conditioning_scale=scale,
+                class_labels=class_labels,
+                timestep_cond=timestep_cond,
+                attention_mask=attention_mask,
+                added_cond_kwargs=added_cond_kwargs,
+                cross_attention_kwargs=cross_attention_kwargs,
+                from_multi=True,
+                guess_mode=guess_mode,
+                return_dict=return_dict,
+            )
+
+            # merge samples
+            if down_block_res_samples is None and mid_block_res_sample is None:
+                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
+            else:
+                down_block_res_samples = [
+                    samples_prev + samples_curr
+                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
+                ]
+                mid_block_res_sample += mid_sample
+
+        return down_block_res_samples, mid_block_res_sample
+
+    # Copied from diffusers.models.controlnets.multicontrolnet.MultiControlNetModel.save_pretrained with ControlNet->ControlNetUnion
+    def save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        is_main_process: bool = True,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+        variant: Optional[str] = None,
+    ):
+        """
+        Save a model and its configuration file to a directory, so that it can be re-loaded using the
+        `[`~models.controlnets.multicontrolnet.MultiControlNetUnionModel.from_pretrained`]` class method.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to which to save. Will be created if it doesn't exist.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful when in distributed training like
+                TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
+                the main process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
+                need to replace `torch.save` by another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+            variant (`str`, *optional*):
+                If specified, weights are saved in the format pytorch_model.<variant>.bin.
+        """
+        for idx, controlnet in enumerate(self.nets):
+            suffix = "" if idx == 0 else f"_{idx}"
+            controlnet.save_pretrained(
+                save_directory + suffix,
+                is_main_process=is_main_process,
+                save_function=save_function,
+                safe_serialization=safe_serialization,
+                variant=variant,
+            )
+
+    @classmethod
+    # Copied from diffusers.models.controlnets.multicontrolnet.MultiControlNetModel.from_pretrained with ControlNet->ControlNetUnion
+    def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
+        r"""
+        Instantiate a pretrained MultiControlNetUnion model from multiple pre-trained controlnet models.
+
+        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+        the model, you should first set it back in training mode with `model.train()`.
+
+        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+        task.
+
+        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+        weights are discarded.
+
+        Parameters:
+            pretrained_model_path (`os.PathLike`):
+                A path to a *directory* containing model weights saved using
+                [`~models.controlnets.multicontrolnet.MultiControlNetUnionModel.save_pretrained`], e.g.,
+                `./my_model_directory/controlnet`.
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+                will be automatically derived from the model's weights.
+            output_loading_info(`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                A map that specifies where each submodule should go. It doesn't need to be refined to each
+                parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
+                same device.
+
+                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+                more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            max_memory (`Dict`, *optional*):
+                A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
+                GPU and the available CPU RAM if unset.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
+                setting this argument to `True` will raise an error.
+            variant (`str`, *optional*):
+                If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
+                ignored when using `from_flax`.
+            use_safetensors (`bool`, *optional*, defaults to `None`):
+                If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
+                `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
+                `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
+        """
+        idx = 0
+        controlnets = []
+
+        # load controlnet and append to list until no controlnet directory exists anymore
+        # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_prertained`
+        # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ...
+        model_path_to_load = pretrained_model_path
+        while os.path.isdir(model_path_to_load):
+            controlnet = ControlNetUnionModel.from_pretrained(model_path_to_load, **kwargs)
+            controlnets.append(controlnet)
+
+            idx += 1
+            model_path_to_load = pretrained_model_path + f"_{idx}"
+
+        logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.")
+
+        if len(controlnets) == 0:
+            raise ValueError(
+                f"No ControlNetUnions found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
+            )
+
+        return cls(controlnets)
```