diffusers 0.27.1__py3-none-any.whl → 0.32.2__py3-none-any.whl
- diffusers/__init__.py +233 -6
- diffusers/callbacks.py +209 -0
- diffusers/commands/env.py +102 -6
- diffusers/configuration_utils.py +45 -16
- diffusers/dependency_versions_table.py +4 -3
- diffusers/image_processor.py +434 -110
- diffusers/loaders/__init__.py +42 -9
- diffusers/loaders/ip_adapter.py +626 -36
- diffusers/loaders/lora_base.py +900 -0
- diffusers/loaders/lora_conversion_utils.py +991 -125
- diffusers/loaders/lora_pipeline.py +3812 -0
- diffusers/loaders/peft.py +571 -7
- diffusers/loaders/single_file.py +405 -173
- diffusers/loaders/single_file_model.py +385 -0
- diffusers/loaders/single_file_utils.py +1783 -713
- diffusers/loaders/textual_inversion.py +41 -23
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +464 -540
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +76 -7
- diffusers/models/activations.py +65 -10
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +605 -18
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +4304 -687
- diffusers/models/autoencoders/__init__.py +8 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +15 -17
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +110 -28
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +19 -24
- diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +21 -18
- diffusers/models/autoencoders/consistency_decoder_vae.py +45 -20
- diffusers/models/autoencoders/vae.py +41 -29
- diffusers/models/autoencoders/vq_model.py +182 -0
- diffusers/models/controlnet.py +47 -800
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +68 -0
- diffusers/models/controlnet_sparsectrl.py +116 -0
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +9 -9
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/controlnets/controlnet_hunyuan.py +401 -0
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/controlnets/controlnet_xs.py +1946 -0
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/downsampling.py +85 -18
- diffusers/models/embeddings.py +1856 -158
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +480 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +2 -7
- diffusers/models/modeling_outputs.py +14 -0
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +611 -146
- diffusers/models/normalization.py +361 -20
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformers/__init__.py +16 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +544 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +542 -0
- diffusers/models/transformers/dit_transformer_2d.py +240 -0
- diffusers/models/transformers/dual_transformer_2d.py +9 -8
- diffusers/models/transformers/hunyuan_transformer_2d.py +578 -0
- diffusers/models/transformers/latte_transformer_3d.py +327 -0
- diffusers/models/transformers/lumina_nextdit2d.py +340 -0
- diffusers/models/transformers/pixart_transformer_2d.py +445 -0
- diffusers/models/transformers/prior_transformer.py +13 -13
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +458 -0
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +297 -187
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +593 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +791 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +461 -0
- diffusers/models/transformers/transformer_temporal.py +21 -19
- diffusers/models/unets/unet_1d.py +8 -8
- diffusers/models/unets/unet_1d_blocks.py +31 -31
- diffusers/models/unets/unet_2d.py +17 -10
- diffusers/models/unets/unet_2d_blocks.py +225 -149
- diffusers/models/unets/unet_2d_condition.py +41 -40
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +192 -1057
- diffusers/models/unets/unet_3d_condition.py +22 -27
- diffusers/models/unets/unet_i2vgen_xl.py +22 -18
- diffusers/models/unets/unet_kandinsky3.py +2 -2
- diffusers/models/unets/unet_motion_model.py +1413 -89
- diffusers/models/unets/unet_spatio_temporal_condition.py +40 -16
- diffusers/models/unets/unet_stable_cascade.py +19 -18
- diffusers/models/unets/uvit_2d.py +2 -2
- diffusers/models/upsampling.py +95 -26
- diffusers/models/vq_model.py +12 -164
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +202 -3
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +8 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +122 -109
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1106 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1288 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1010 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +236 -180
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +58 -39
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +121 -36
- diffusers/pipelines/aura_flow/__init__.py +48 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +584 -0
- diffusers/pipelines/auto_pipeline.py +196 -28
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +6 -6
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/cogvideo/__init__.py +54 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +772 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +885 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +851 -0
- diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +6 -6
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -182
- diffusers/pipelines/controlnet/pipeline_controlnet.py +134 -87
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +93 -77
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +88 -197
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +136 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +176 -80
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +125 -89
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1060 -0
- diffusers/pipelines/controlnet_sd3/__init__.py +57 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1133 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +916 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1111 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +16 -30
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +20 -35
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +23 -41
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +22 -38
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +25 -41
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +19 -34
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +70 -30
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +48 -25
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +21 -20
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +27 -29
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +33 -27
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +33 -23
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +36 -30
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +102 -69
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +7 -4
- diffusers/pipelines/flux/__init__.py +69 -0
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +957 -0
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +37 -0
- diffusers/pipelines/free_init_utils.py +41 -38
- diffusers/pipelines/free_noise_utils.py +596 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/__init__.py +48 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +916 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +32 -29
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +34 -31
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +22 -35
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +26 -37
- diffusers/pipelines/kolors/__init__.py +54 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1250 -0
- diffusers/pipelines/kolors/pipeline_output.py +21 -0
- diffusers/pipelines/kolors/text_encoder.py +889 -0
- diffusers/pipelines/kolors/tokenizer.py +338 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +82 -62
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +77 -60
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +12 -12
- diffusers/pipelines/latte/__init__.py +48 -0
- diffusers/pipelines/latte/pipeline_latte.py +881 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +80 -74
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +85 -76
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/__init__.py +48 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +890 -0
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +576 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/pag/__init__.py +80 -0
- diffusers/pipelines/pag/pag_utils.py +243 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1328 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1610 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +969 -0
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +865 -0
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +1062 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +994 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +866 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1345 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1544 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1776 -0
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +74 -164
- diffusers/pipelines/pipeline_flax_utils.py +5 -10
- diffusers/pipelines/pipeline_loading_utils.py +515 -53
- diffusers/pipelines/pipeline_utils.py +411 -222
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +76 -93
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +873 -0
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +27 -23
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_audio/__init__.py +50 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +756 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +71 -25
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +35 -34
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +20 -11
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +145 -79
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +43 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +100 -68
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +109 -201
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +131 -32
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +247 -87
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +30 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +35 -27
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +49 -42
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1036 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1250 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +29 -20
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +59 -58
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +31 -25
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +38 -22
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -24
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -23
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +107 -67
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +316 -69
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +98 -30
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +121 -83
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +161 -105
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +142 -218
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -29
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +69 -39
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +105 -74
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +29 -49
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +32 -93
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +37 -25
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +54 -40
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +12 -12
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +29 -28
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +6 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +15 -14
- diffusers/{models/dual_transformer_2d.py → quantizers/__init__.py} +2 -6
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
- diffusers/schedulers/__init__.py +12 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +23 -25
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
- diffusers/schedulers/scheduling_ddim.py +27 -26
- diffusers/schedulers/scheduling_ddim_cogvideox.py +452 -0
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +32 -31
- diffusers/schedulers/scheduling_ddpm.py +27 -30
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +33 -36
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +150 -50
- diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +221 -84
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +158 -52
- diffusers/schedulers/scheduling_dpmsolver_sde.py +153 -34
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +275 -86
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +81 -57
- diffusers/schedulers/scheduling_edm_euler.py +62 -39
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +30 -29
- diffusers/schedulers/scheduling_euler_discrete.py +255 -74
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +458 -0
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +320 -0
- diffusers/schedulers/scheduling_heun_discrete.py +174 -46
- diffusers/schedulers/scheduling_ipndm.py +9 -9
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +138 -29
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +132 -26
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +23 -29
- diffusers/schedulers/scheduling_lms_discrete.py +105 -28
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +21 -21
- diffusers/schedulers/scheduling_sasolver.py +157 -60
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +41 -36
- diffusers/schedulers/scheduling_unclip.py +19 -16
- diffusers/schedulers/scheduling_unipc_multistep.py +243 -47
- diffusers/schedulers/scheduling_utils.py +12 -5
- diffusers/schedulers/scheduling_utils_flax.py +1 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +214 -30
- diffusers/utils/__init__.py +17 -1
- diffusers/utils/constants.py +3 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +592 -7
- diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +1001 -71
- diffusers/utils/dynamic_modules_utils.py +34 -29
- diffusers/utils/export_utils.py +50 -6
- diffusers/utils/hub_utils.py +131 -17
- diffusers/utils/import_utils.py +210 -8
- diffusers/utils/loading_utils.py +118 -5
- diffusers/utils/logging.py +4 -2
- diffusers/utils/peft_utils.py +37 -7
- diffusers/utils/state_dict_utils.py +13 -2
- diffusers/utils/testing_utils.py +193 -11
- diffusers/utils/torch_utils.py +4 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/METADATA +82 -91
- diffusers-0.32.2.dist-info/RECORD +550 -0
- {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/WHEEL +1 -1
- diffusers/loaders/autoencoder.py +0 -146
- diffusers/loaders/controlnet.py +0 -136
- diffusers/loaders/lora.py +0 -1349
- diffusers/models/prior_transformer.py +0 -12
- diffusers/models/t5_film_transformer.py +0 -70
- diffusers/models/transformer_2d.py +0 -25
- diffusers/models/transformer_temporal.py +0 -34
- diffusers/models/unet_1d.py +0 -26
- diffusers/models/unet_1d_blocks.py +0 -203
- diffusers/models/unet_2d.py +0 -27
- diffusers/models/unet_2d_blocks.py +0 -375
- diffusers/models/unet_2d_condition.py +0 -25
- diffusers-0.27.1.dist-info/RECORD +0 -399
- {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/LICENSE +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/top_level.txt +0 -0
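Two mechanical renames recur throughout the per-file diffs that follow: `LoraLoaderMixin` becomes `StableDiffusionLoraLoaderMixin`, and type hints and docstrings move from the deprecated `torch.FloatTensor` alias to plain `torch.Tensor`. A minimal migration sketch (a sketch only, assuming nothing beyond the names visible in the hunks below; the try/except keeps it working on both sides of the upgrade):

import torch

try:
    # 0.32.x side of this diff: pipeline-specific LoRA loader mixins
    from diffusers.loaders import StableDiffusionLoraLoaderMixin as LoraMixin
except ImportError:
    # 0.27.x side: the generic mixin
    from diffusers.loaders import LoraLoaderMixin as LoraMixin


def supports_lora(pipe) -> bool:
    # Same isinstance check the updated pipeline code performs internally.
    return isinstance(pipe, LoraMixin)


# New-style annotation: plain torch.Tensor instead of the torch.FloatTensor alias.
def scale_latents(latents: torch.Tensor, factor: float) -> torch.Tensor:
    return latents * factor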
diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
CHANGED
@@ -29,7 +29,7 @@ from transformers import (
 )

 from ....image_processor import PipelineImageInput, VaeImageProcessor
-from ....loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ....loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from ....models import AutoencoderKL, UNet2DConditionModel
 from ....models.attention_processor import Attention
 from ....models.lora import adjust_lora_scale_text_encoder
@@ -60,14 +60,14 @@ class Pix2PixInversionPipelineOutput(BaseOutput, TextualInversionLoaderMixin):
     Output class for Stable Diffusion pipelines.

     Args:
-        latents (`torch.FloatTensor`)
+        latents (`torch.Tensor`)
             inverted latents tensor
         images (`List[PIL.Image.Image]` or `np.ndarray`)
             List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
             num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
     """

-    latents: torch.FloatTensor
+    latents: torch.Tensor
     images: Union[List[PIL.Image.Image], np.ndarray]


@@ -377,8 +377,8 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
         num_images_per_prompt,
         do_classifier_free_guidance,
         negative_prompt=None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         lora_scale: Optional[float] = None,
         **kwargs,
     ):
@@ -410,8 +410,8 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
         num_images_per_prompt,
         do_classifier_free_guidance,
         negative_prompt=None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         lora_scale: Optional[float] = None,
         clip_skip: Optional[int] = None,
     ):
@@ -431,10 +431,10 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
@@ -446,7 +446,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
         """
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
             self._lora_scale = lora_scale

             # dynamically adjust the LoRA scale
@@ -578,9 +578,10 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
-            # Retrieve the original scale by scaling back the LoRA layers
-            unscale_lora_layers(self.text_encoder, lora_scale)
+        if self.text_encoder is not None:
+            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)

         return prompt_embeds, negative_prompt_embeds

@@ -661,7 +662,12 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
-        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        shape = (
+            batch_size,
+            num_channels_latents,
+            int(height) // self.vae_scale_factor,
+            int(width) // self.vae_scale_factor,
+        )
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -702,7 +708,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
         return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0)

     @torch.no_grad()
-    def get_embeds(self, prompt: List[str], batch_size: int = 16) -> torch.FloatTensor:
+    def get_embeds(self, prompt: List[str], batch_size: int = 16) -> torch.Tensor:
         num_prompts = len(prompt)
         embeds = []
         for i in range(0, num_prompts, batch_size):
@@ -822,13 +828,13 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         cross_attention_guidance_amount: float = 0.1,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         callback_steps: Optional[int] = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         clip_skip: Optional[int] = None,
@@ -871,14 +877,14 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor will ge generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
@@ -892,7 +898,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
                 plain tuple.
             callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. The function will be
-                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
@@ -1107,12 +1113,12 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
         num_inference_steps: int = 50,
         guidance_scale: float = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
         cross_attention_guidance_amount: float = 0.1,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
         callback_steps: Optional[int] = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         lambda_auto_corr: float = 20.0,
@@ -1127,7 +1133,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+            image (`torch.Tensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                 `Image`, or tensor representing an image batch which will be used for conditioning. Can also accept
                 image latents as `image`, if passing latents directly, it will not be encoded again.
             num_inference_steps (`int`, *optional*, defaults to 50):
@@ -1142,11 +1148,11 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor will ge generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
             cross_attention_guidance_amount (`float`, defaults to 0.1):
@@ -1159,7 +1165,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
                 plain tuple.
             callback (`Callable`, *optional*):
                 A function that will be called every `callback_steps` steps during inference. The function will be
-                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
CHANGED
@@ -363,6 +363,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
     """

     _supports_gradient_checkpointing = True
+    _no_split_modules = ["BasicTransformerBlock", "ResnetBlockFlat", "CrossAttnUpBlockFlat"]

     @register_to_config
     def __init__(
@@ -531,7 +532,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         elif encoder_hid_dim_type == "text_image_proj":
             # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
             # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
-            # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
+            # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
             self.encoder_hid_proj = TextImageProjection(
                 text_embed_dim=encoder_hid_dim,
                 image_embed_dim=cross_attention_dim,
@@ -545,7 +546,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             )
         elif encoder_hid_dim_type is not None:
             raise ValueError(
-                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+                f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
             )
         else:
             self.encoder_hid_proj = None
@@ -591,7 +592,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         elif addition_embed_type == "text_image":
             # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
             # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
-            # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
+            # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
             self.add_embedding = TextImageTimeEmbedding(
                 text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
             )
@@ -816,7 +817,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             positive_len = 768
             if isinstance(cross_attention_dim, int):
                 positive_len = cross_attention_dim
-            elif isinstance(cross_attention_dim, list):
+            elif isinstance(cross_attention_dim, (list, tuple)):
                 positive_len = cross_attention_dim[0]

             feature_type = "text-only" if attention_type == "gated" else "text-image"
@@ -836,7 +837,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):

         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor()

             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
@@ -1000,8 +1001,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):

     def fuse_qkv_projections(self):
         """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
-        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.

         <Tip warning={true}>

@@ -1047,7 +1048,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):

     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
         class_labels: Optional[torch.Tensor] = None,
@@ -1065,10 +1066,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         The [`UNetFlatConditionModel`] forward method.

         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The noisy input tensor with the following shape `(batch, channel, height, width)`.
-            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
-            encoder_hidden_states (`torch.FloatTensor`):
+            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.Tensor`):
                 The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
             class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                 Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
@@ -1112,8 +1113,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):

         Returns:
             [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
-                If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
-                a `tuple` is returned where the first element is the sample tensor.
+                If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
+                otherwise a `tuple` is returned where the first element is the sample tensor.
         """
         # By default samples have to be AT least a multiple of the overall upsampling factor.
         # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
@@ -1257,7 +1258,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
             encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
         elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
-            # Kadinsky 2.1 - style
+            # Kandinsky 2.1 - style
             if "image_embeds" not in added_cond_kwargs:
                 raise ValueError(
                     f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
@@ -1589,12 +1590,12 @@ class DownBlockFlat(nn.Module):
         self.gradient_checkpointing = False

     def forward(
-        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         output_states = ()

         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1718,20 +1719,20 @@ class CrossAttnDownBlockFlat(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        additional_residuals: Optional[torch.FloatTensor] = None,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        additional_residuals: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         output_states = ()

         blocks = list(zip(self.resnets, self.attentions))

         for i, (resnet, attn) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1836,13 +1837,13 @@ class UpBlockFlat(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1873,7 +1874,7 @@ class UpBlockFlat(nn.Module):

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1993,18 +1994,18 @@ class CrossAttnUpBlockFlat(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         upsample_size: Optional[int] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

         is_freeu_enabled = (
             getattr(self, "s1", None)
@@ -2032,7 +2033,7 @@ class CrossAttnUpBlockFlat(nn.Module):

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -2103,8 +2104,8 @@ class UNetMidBlockFlat(nn.Module):
         output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.

     Returns:
-        `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
-        in_channels, height, width)`.
+        `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels,
+        height, width)`.

     """

@@ -2222,12 +2223,35 @@ class UNetMidBlockFlat(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)

-    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+        self.gradient_checkpointing = False
+
+    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if attn is not None:
-                hidden_states = attn(hidden_states, temb=temb)
-            hidden_states = resnet(hidden_states, temb)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = resnet(hidden_states, temb)

         return hidden_states

@@ -2238,6 +2262,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         self,
         in_channels: int,
         temb_channels: int,
+        out_channels: Optional[int] = None,
         dropout: float = 0.0,
         num_layers: int = 1,
         transformer_layers_per_block: Union[int, Tuple[int]] = 1,
@@ -2245,6 +2270,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
+        resnet_groups_out: Optional[int] = None,
         resnet_pre_norm: bool = True,
         num_attention_heads: int = 1,
         output_scale_factor: float = 1.0,
@@ -2256,6 +2282,10 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
     ):
         super().__init__()

+        out_channels = out_channels or in_channels
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
         self.has_cross_attention = True
         self.num_attention_heads = num_attention_heads
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
@@ -2264,14 +2294,17 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         if isinstance(transformer_layers_per_block, int):
             transformer_layers_per_block = [transformer_layers_per_block] * num_layers

+        resnet_groups_out = resnet_groups_out or resnet_groups
+
         # there is always at least one resnet
         resnets = [
             ResnetBlockFlat(
                 in_channels=in_channels,
-                out_channels=in_channels,
+                out_channels=out_channels,
                 temb_channels=temb_channels,
                 eps=resnet_eps,
                 groups=resnet_groups,
+                groups_out=resnet_groups_out,
                 dropout=dropout,
                 time_embedding_norm=resnet_time_scale_shift,
                 non_linearity=resnet_act_fn,
@@ -2286,11 +2319,11 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
                 attentions.append(
                     Transformer2DModel(
                         num_attention_heads,
-                        in_channels // num_attention_heads,
-                        in_channels=in_channels,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
                         num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
-                        norm_num_groups=resnet_groups,
+                        norm_num_groups=resnet_groups_out,
                         use_linear_projection=use_linear_projection,
                         upcast_attention=upcast_attention,
                         attention_type=attention_type,
@@ -2300,8 +2333,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
                 attentions.append(
                     DualTransformer2DModel(
                         num_attention_heads,
-                        in_channels // num_attention_heads,
-                        in_channels=in_channels,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
@@ -2309,11 +2342,11 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
                 )
             resnets.append(
                 ResnetBlockFlat(
-                    in_channels=in_channels,
-                    out_channels=in_channels,
+                    in_channels=out_channels,
+                    out_channels=out_channels,
                     temb_channels=temb_channels,
                     eps=resnet_eps,
-                    groups=resnet_groups,
+                    groups=resnet_groups_out,
                     dropout=dropout,
                     time_embedding_norm=resnet_time_scale_shift,
                     non_linearity=resnet_act_fn,
@@ -2329,20 +2362,20 @@ class UNetMidBlockFlatCrossAttn(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -2470,16 +2503,16 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         if cross_attention_kwargs.get("scale", None) is not None:
-            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

         if attention_mask is None:
             # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.