diffusers 0.27.1__py3-none-any.whl → 0.32.2__py3-none-any.whl
- diffusers/__init__.py +233 -6
- diffusers/callbacks.py +209 -0
- diffusers/commands/env.py +102 -6
- diffusers/configuration_utils.py +45 -16
- diffusers/dependency_versions_table.py +4 -3
- diffusers/image_processor.py +434 -110
- diffusers/loaders/__init__.py +42 -9
- diffusers/loaders/ip_adapter.py +626 -36
- diffusers/loaders/lora_base.py +900 -0
- diffusers/loaders/lora_conversion_utils.py +991 -125
- diffusers/loaders/lora_pipeline.py +3812 -0
- diffusers/loaders/peft.py +571 -7
- diffusers/loaders/single_file.py +405 -173
- diffusers/loaders/single_file_model.py +385 -0
- diffusers/loaders/single_file_utils.py +1783 -713
- diffusers/loaders/textual_inversion.py +41 -23
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +464 -540
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +76 -7
- diffusers/models/activations.py +65 -10
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +605 -18
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +4304 -687
- diffusers/models/autoencoders/__init__.py +8 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +15 -17
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +110 -28
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +19 -24
- diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +21 -18
- diffusers/models/autoencoders/consistency_decoder_vae.py +45 -20
- diffusers/models/autoencoders/vae.py +41 -29
- diffusers/models/autoencoders/vq_model.py +182 -0
- diffusers/models/controlnet.py +47 -800
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +68 -0
- diffusers/models/controlnet_sparsectrl.py +116 -0
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +9 -9
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/controlnets/controlnet_hunyuan.py +401 -0
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/controlnets/controlnet_xs.py +1946 -0
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/downsampling.py +85 -18
- diffusers/models/embeddings.py +1856 -158
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +480 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +2 -7
- diffusers/models/modeling_outputs.py +14 -0
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +611 -146
- diffusers/models/normalization.py +361 -20
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformers/__init__.py +16 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +544 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +542 -0
- diffusers/models/transformers/dit_transformer_2d.py +240 -0
- diffusers/models/transformers/dual_transformer_2d.py +9 -8
- diffusers/models/transformers/hunyuan_transformer_2d.py +578 -0
- diffusers/models/transformers/latte_transformer_3d.py +327 -0
- diffusers/models/transformers/lumina_nextdit2d.py +340 -0
- diffusers/models/transformers/pixart_transformer_2d.py +445 -0
- diffusers/models/transformers/prior_transformer.py +13 -13
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +458 -0
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +297 -187
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +593 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +791 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +461 -0
- diffusers/models/transformers/transformer_temporal.py +21 -19
- diffusers/models/unets/unet_1d.py +8 -8
- diffusers/models/unets/unet_1d_blocks.py +31 -31
- diffusers/models/unets/unet_2d.py +17 -10
- diffusers/models/unets/unet_2d_blocks.py +225 -149
- diffusers/models/unets/unet_2d_condition.py +41 -40
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +192 -1057
- diffusers/models/unets/unet_3d_condition.py +22 -27
- diffusers/models/unets/unet_i2vgen_xl.py +22 -18
- diffusers/models/unets/unet_kandinsky3.py +2 -2
- diffusers/models/unets/unet_motion_model.py +1413 -89
- diffusers/models/unets/unet_spatio_temporal_condition.py +40 -16
- diffusers/models/unets/unet_stable_cascade.py +19 -18
- diffusers/models/unets/uvit_2d.py +2 -2
- diffusers/models/upsampling.py +95 -26
- diffusers/models/vq_model.py +12 -164
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +202 -3
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +8 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +122 -109
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1106 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1288 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1010 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +236 -180
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +58 -39
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +121 -36
- diffusers/pipelines/aura_flow/__init__.py +48 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +584 -0
- diffusers/pipelines/auto_pipeline.py +196 -28
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +6 -6
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/cogvideo/__init__.py +54 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +772 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +885 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +851 -0
- diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +6 -6
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -182
- diffusers/pipelines/controlnet/pipeline_controlnet.py +134 -87
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +93 -77
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +88 -197
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +136 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +176 -80
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +125 -89
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1060 -0
- diffusers/pipelines/controlnet_sd3/__init__.py +57 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1133 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +916 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1111 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +16 -30
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +20 -35
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +23 -41
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +22 -38
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +25 -41
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +19 -34
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +70 -30
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +48 -25
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +21 -20
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +27 -29
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +33 -27
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +33 -23
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +36 -30
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +102 -69
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +7 -4
- diffusers/pipelines/flux/__init__.py +69 -0
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +957 -0
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +37 -0
- diffusers/pipelines/free_init_utils.py +41 -38
- diffusers/pipelines/free_noise_utils.py +596 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/__init__.py +48 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +916 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +32 -29
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +34 -31
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +22 -35
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +26 -37
- diffusers/pipelines/kolors/__init__.py +54 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1250 -0
- diffusers/pipelines/kolors/pipeline_output.py +21 -0
- diffusers/pipelines/kolors/text_encoder.py +889 -0
- diffusers/pipelines/kolors/tokenizer.py +338 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +82 -62
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +77 -60
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +12 -12
- diffusers/pipelines/latte/__init__.py +48 -0
- diffusers/pipelines/latte/pipeline_latte.py +881 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +80 -74
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +85 -76
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/__init__.py +48 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +890 -0
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +576 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/pag/__init__.py +80 -0
- diffusers/pipelines/pag/pag_utils.py +243 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1328 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1610 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +969 -0
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +865 -0
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +1062 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +994 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +866 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1345 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1544 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1776 -0
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +74 -164
- diffusers/pipelines/pipeline_flax_utils.py +5 -10
- diffusers/pipelines/pipeline_loading_utils.py +515 -53
- diffusers/pipelines/pipeline_utils.py +411 -222
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +76 -93
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +873 -0
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +27 -23
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_audio/__init__.py +50 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +756 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +71 -25
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +35 -34
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +20 -11
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +145 -79
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +43 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +100 -68
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +109 -201
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +131 -32
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +247 -87
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +30 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +35 -27
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +49 -42
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1036 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1250 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +29 -20
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +59 -58
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +31 -25
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +38 -22
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -24
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -23
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +107 -67
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +316 -69
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +98 -30
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +121 -83
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +161 -105
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +142 -218
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -29
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +69 -39
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +105 -74
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +29 -49
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +32 -93
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +37 -25
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +54 -40
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +12 -12
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +29 -28
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +6 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +15 -14
- diffusers/{models/dual_transformer_2d.py → quantizers/__init__.py} +2 -6
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
- diffusers/schedulers/__init__.py +12 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +23 -25
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
- diffusers/schedulers/scheduling_ddim.py +27 -26
- diffusers/schedulers/scheduling_ddim_cogvideox.py +452 -0
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +32 -31
- diffusers/schedulers/scheduling_ddpm.py +27 -30
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +33 -36
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +150 -50
- diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +221 -84
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +158 -52
- diffusers/schedulers/scheduling_dpmsolver_sde.py +153 -34
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +275 -86
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +81 -57
- diffusers/schedulers/scheduling_edm_euler.py +62 -39
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +30 -29
- diffusers/schedulers/scheduling_euler_discrete.py +255 -74
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +458 -0
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +320 -0
- diffusers/schedulers/scheduling_heun_discrete.py +174 -46
- diffusers/schedulers/scheduling_ipndm.py +9 -9
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +138 -29
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +132 -26
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +23 -29
- diffusers/schedulers/scheduling_lms_discrete.py +105 -28
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +21 -21
- diffusers/schedulers/scheduling_sasolver.py +157 -60
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +41 -36
- diffusers/schedulers/scheduling_unclip.py +19 -16
- diffusers/schedulers/scheduling_unipc_multistep.py +243 -47
- diffusers/schedulers/scheduling_utils.py +12 -5
- diffusers/schedulers/scheduling_utils_flax.py +1 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +214 -30
- diffusers/utils/__init__.py +17 -1
- diffusers/utils/constants.py +3 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +592 -7
- diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +1001 -71
- diffusers/utils/dynamic_modules_utils.py +34 -29
- diffusers/utils/export_utils.py +50 -6
- diffusers/utils/hub_utils.py +131 -17
- diffusers/utils/import_utils.py +210 -8
- diffusers/utils/loading_utils.py +118 -5
- diffusers/utils/logging.py +4 -2
- diffusers/utils/peft_utils.py +37 -7
- diffusers/utils/state_dict_utils.py +13 -2
- diffusers/utils/testing_utils.py +193 -11
- diffusers/utils/torch_utils.py +4 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/METADATA +82 -91
- diffusers-0.32.2.dist-info/RECORD +550 -0
- {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/WHEEL +1 -1
- diffusers/loaders/autoencoder.py +0 -146
- diffusers/loaders/controlnet.py +0 -136
- diffusers/loaders/lora.py +0 -1349
- diffusers/models/prior_transformer.py +0 -12
- diffusers/models/t5_film_transformer.py +0 -70
- diffusers/models/transformer_2d.py +0 -25
- diffusers/models/transformer_temporal.py +0 -34
- diffusers/models/unet_1d.py +0 -26
- diffusers/models/unet_1d_blocks.py +0 -203
- diffusers/models/unet_2d.py +0 -27
- diffusers/models/unet_2d_blocks.py +0 -375
- diffusers/models/unet_2d_condition.py +0 -25
- diffusers-0.27.1.dist-info/RECORD +0 -399
- {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/LICENSE +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/top_level.txt +0 -0
diffusers/models/unets/unet_2d_blocks.py

@@ -561,7 +561,7 @@ class AutoencoderTinyBlock(nn.Module):
             ` The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.

     Returns:
-        `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
+        `torch.Tensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
         `out_channels`.
     """

@@ -582,7 +582,7 @@ class AutoencoderTinyBlock(nn.Module):
         )
         self.fuse = nn.ReLU()

-    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.fuse(self.conv(x) + self.skip(x))

@@ -612,8 +612,8 @@ class UNetMidBlock2D(nn.Module):
         output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.

     Returns:
-        `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
-        in_channels, height, width)`.
+        `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels,
+        height, width)`.

     """
@@ -731,12 +731,35 @@ class UNetMidBlock2D(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)

-    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+        self.gradient_checkpointing = False
+
+    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if attn is not None:
-                hidden_states = attn(hidden_states, temb=temb)
-            hidden_states = resnet(hidden_states, temb)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = resnet(hidden_states, temb)

         return hidden_states

@@ -746,6 +769,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         self,
         in_channels: int,
         temb_channels: int,
+        out_channels: Optional[int] = None,
         dropout: float = 0.0,
         num_layers: int = 1,
         transformer_layers_per_block: Union[int, Tuple[int]] = 1,
@@ -753,6 +777,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
+        resnet_groups_out: Optional[int] = None,
         resnet_pre_norm: bool = True,
         num_attention_heads: int = 1,
         output_scale_factor: float = 1.0,
@@ -764,6 +789,10 @@ class UNetMidBlock2DCrossAttn(nn.Module):
     ):
         super().__init__()

+        out_channels = out_channels or in_channels
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
         self.has_cross_attention = True
         self.num_attention_heads = num_attention_heads
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
@@ -772,14 +801,17 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         if isinstance(transformer_layers_per_block, int):
             transformer_layers_per_block = [transformer_layers_per_block] * num_layers

+        resnet_groups_out = resnet_groups_out or resnet_groups
+
         # there is always at least one resnet
         resnets = [
             ResnetBlock2D(
                 in_channels=in_channels,
-                out_channels=in_channels,
+                out_channels=out_channels,
                 temb_channels=temb_channels,
                 eps=resnet_eps,
                 groups=resnet_groups,
+                groups_out=resnet_groups_out,
                 dropout=dropout,
                 time_embedding_norm=resnet_time_scale_shift,
                 non_linearity=resnet_act_fn,
@@ -794,11 +826,11 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                 attentions.append(
                     Transformer2DModel(
                         num_attention_heads,
-                        in_channels // num_attention_heads,
-                        in_channels=in_channels,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
                         num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
-                        norm_num_groups=resnet_groups,
+                        norm_num_groups=resnet_groups_out,
                         use_linear_projection=use_linear_projection,
                         upcast_attention=upcast_attention,
                         attention_type=attention_type,
@@ -808,8 +840,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                 attentions.append(
                     DualTransformer2DModel(
                         num_attention_heads,
-                        in_channels // num_attention_heads,
-                        in_channels=in_channels,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
@@ -817,11 +849,11 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                 )
             resnets.append(
                 ResnetBlock2D(
-                    in_channels=in_channels,
-                    out_channels=in_channels,
+                    in_channels=out_channels,
+                    out_channels=out_channels,
                     temb_channels=temb_channels,
                     eps=resnet_eps,
-                    groups=resnet_groups,
+                    groups=resnet_groups_out,
                     dropout=dropout,
                     time_embedding_norm=resnet_time_scale_shift,
                     non_linearity=resnet_act_fn,
@@ -837,20 +869,20 @@ class UNetMidBlock2DCrossAttn(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -977,16 +1009,16 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         if cross_attention_kwargs.get("scale", None) is not None:
-            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

         if attention_mask is None:
             # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
@@ -1107,23 +1139,46 @@ class AttnDownBlock2D(nn.Module):
         else:
             self.downsamplers = None

+        self.gradient_checkpointing = False
+
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         if cross_attention_kwargs.get("scale", None) is not None:
-            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

         output_states = ()

         for resnet, attn in zip(self.resnets, self.attentions):
-            hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states, **cross_attention_kwargs)
-            output_states = output_states + (hidden_states,)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(hidden_states, **cross_attention_kwargs)
+                output_states = output_states + (hidden_states,)
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states, **cross_attention_kwargs)
+                output_states = output_states + (hidden_states,)

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
@@ -1231,24 +1286,24 @@ class CrossAttnDownBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        additional_residuals: Optional[torch.FloatTensor] = None,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        additional_residuals: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

         output_states = ()

         blocks = list(zip(self.resnets, self.attentions))

         for i, (resnet, attn) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1353,8 +1408,8 @@ class DownBlock2D(nn.Module):
         self.gradient_checkpointing = False

     def forward(
-        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1362,7 +1417,7 @@ class DownBlock2D(nn.Module):
         output_states = ()

         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1456,7 +1511,7 @@ class DownEncoderBlock2D(nn.Module):
         else:
             self.downsamplers = None

-    def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1558,7 +1613,7 @@ class AttnDownEncoderBlock2D(nn.Module):
         else:
             self.downsamplers = None

-    def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1657,12 +1712,12 @@ class AttnSkipDownBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        skip_sample: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        skip_sample: Optional[torch.Tensor] = None,
         *args,
         **kwargs,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1748,12 +1803,12 @@ class SkipDownBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        skip_sample: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        skip_sample: Optional[torch.Tensor] = None,
         *args,
         **kwargs,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1841,8 +1896,8 @@ class ResnetDownsampleBlock2D(nn.Module):
         self.gradient_checkpointing = False

     def forward(
-        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1850,7 +1905,7 @@ class ResnetDownsampleBlock2D(nn.Module):
         output_states = ()

         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1977,16 +2032,16 @@ class SimpleCrossAttnDownBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         if cross_attention_kwargs.get("scale", None) is not None:
-            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

         output_states = ()

@@ -2002,7 +2057,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
             mask = attention_mask

         for resnet, attn in zip(self.resnets, self.attentions):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -2088,8 +2143,8 @@ class KDownBlock2D(nn.Module):
         self.gradient_checkpointing = False

     def forward(
-        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -2097,7 +2152,7 @@ class KDownBlock2D(nn.Module):
         output_states = ()

         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -2192,21 +2247,21 @@ class KCrossAttnDownBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         if cross_attention_kwargs.get("scale", None) is not None:
-            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

         output_states = ()

         for resnet, attn in zip(self.resnets, self.attentions):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -2345,17 +2400,18 @@ class AttnUpBlock2D(nn.Module):
         else:
             self.upsamplers = None

+        self.gradient_checkpointing = False
         self.resolution_idx = resolution_idx

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -2366,8 +2422,28 @@ class AttnUpBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(hidden_states)
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states)

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
@@ -2472,18 +2548,18 @@ class CrossAttnUpBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         upsample_size: Optional[int] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

         is_freeu_enabled = (
             getattr(self, "s1", None)
@@ -2511,7 +2587,7 @@ class CrossAttnUpBlock2D(nn.Module):

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -2607,13 +2683,13 @@ class UpBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -2644,7 +2720,7 @@ class UpBlock2D(nn.Module):

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -2732,7 +2808,7 @@ class UpDecoderBlock2D(nn.Module):

         self.resolution_idx = resolution_idx

-    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         for resnet in self.resnets:
             hidden_states = resnet(hidden_states, temb=temb)

@@ -2830,7 +2906,7 @@ class AttnUpDecoderBlock2D(nn.Module):

         self.resolution_idx = resolution_idx

-    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         for resnet, attn in zip(self.resnets, self.attentions):
             hidden_states = resnet(hidden_states, temb=temb)
             hidden_states = attn(hidden_states, temb=temb)
@@ -2938,13 +3014,13 @@ class AttnSkipUpBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         skip_sample=None,
         *args,
         **kwargs,
-    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -3050,13 +3126,13 @@ class SkipUpBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         skip_sample=None,
         *args,
         **kwargs,
-    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -3157,13 +3233,13 @@ class ResnetUpsampleBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -3174,7 +3250,7 @@ class ResnetUpsampleBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -3301,18 +3377,18 @@ class SimpleCrossAttnUpBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         if cross_attention_kwargs.get("scale", None) is not None:
-            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

         if attention_mask is None:
             # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
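Note (illustrative, not part of the upstream diff): the context around this hunk shows how SimpleCrossAttnUpBlock2D picks which mask to hand to attention: an explicit `attention_mask` wins; otherwise cross-attention (when `encoder_hidden_states` is given) falls back to `encoder_attention_mask`, and self-attention runs unmasked. A minimal sketch of that selection logic, written as a free function rather than the real class:

    from typing import Optional

    import torch


    def select_attention_mask(
        attention_mask: Optional[torch.Tensor],
        encoder_hidden_states: Optional[torch.Tensor],
        encoder_attention_mask: Optional[torch.Tensor],
    ) -> Optional[torch.Tensor]:
        # Mirrors the up-block logic: an explicit attention_mask always wins;
        # otherwise cross-attention uses the encoder mask, self-attention uses none.
        if attention_mask is None:
            return encoder_attention_mask if encoder_hidden_states is not None else None
        return attention_mask


    token_mask = torch.ones(1, 77, dtype=torch.bool)
    print(select_attention_mask(None, torch.randn(1, 77, 768), token_mask) is token_mask)  # True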
@@ -3332,7 +3408,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -3419,13 +3495,13 @@ class KUpBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -3435,7 +3511,7 @@ class KUpBlock2D(nn.Module):
             hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)

         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -3549,21 +3625,21 @@ class KCrossAttnUpBlock2D(nn.Module):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         upsample_size: Optional[int] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         res_hidden_states_tuple = res_hidden_states_tuple[-1]
         if res_hidden_states_tuple is not None:
             hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)

         for resnet, attn in zip(self.resnets, self.attentions):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -3675,26 +3751,26 @@ class KAttentionBlock(nn.Module):
             cross_attention_norm=cross_attention_norm,
         )

-    def _to_3d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor:
+    def _to_3d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor:
         return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1)

-    def _to_4d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor:
+    def _to_4d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor:
         return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight)

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         # TODO: mark emb as non-optional (self.norm2 requires it).
         #       requires assessing impact of change to positional param interface.
-        emb: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        emb: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
         if cross_attention_kwargs.get("scale", None) is not None:
-            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

         # 1. Self-Attention
         if self.add_self_attention:
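Note (illustrative, not part of the upstream diff): KAttentionBlock's `_to_3d`/`_to_4d` helpers shown above convert between the convolutional (B, C, H, W) layout and the token layout (B, H*W, C) expected by attention (the upstream code spells the width parameter `weight`). A tiny standalone check of that round trip, using the same permute/reshape recipe:

    import torch


    def to_3d(hidden_states: torch.Tensor, height: int, width: int) -> torch.Tensor:
        # (B, C, H, W) -> (B, H*W, C): channels become the per-token feature dim.
        return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * width, -1)


    def to_4d(hidden_states: torch.Tensor, height: int, width: int) -> torch.Tensor:
        # (B, H*W, C) -> (B, C, H, W): inverse of to_3d.
        return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, width)


    x = torch.randn(2, 320, 16, 24)              # B, C, H, W
    tokens = to_3d(x, 16, 24)                    # -> (2, 384, 320)
    assert torch.equal(to_4d(tokens, 16, 24), x)  # round trip is lossless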