diffusers 0.26.2__py3-none-any.whl → 0.27.0__py3-none-any.whl
- diffusers/__init__.py +20 -1
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/diffusers_cli.py +1 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +7 -3
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +2 -2
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/image_processor.py +110 -4
- diffusers/loaders/autoencoder.py +28 -8
- diffusers/loaders/controlnet.py +17 -8
- diffusers/loaders/ip_adapter.py +86 -23
- diffusers/loaders/lora.py +105 -310
- diffusers/loaders/lora_conversion_utils.py +1 -1
- diffusers/loaders/peft.py +1 -1
- diffusers/loaders/single_file.py +51 -12
- diffusers/loaders/single_file_utils.py +278 -49
- diffusers/loaders/textual_inversion.py +23 -4
- diffusers/loaders/unet.py +195 -41
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +3 -1
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +26 -36
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +171 -114
- diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +1 -1
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flax.py +1 -1
- diffusers/models/downsampling.py +8 -12
- diffusers/models/dual_transformer_2d.py +1 -1
- diffusers/models/embeddings.py +3 -4
- diffusers/models/embeddings_flax.py +1 -1
- diffusers/models/lora.py +33 -10
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +4 -6
- diffusers/models/normalization.py +1 -1
- diffusers/models/resnet.py +31 -58
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/t5_film_transformer.py +1 -1
- diffusers/models/transformer_2d.py +1 -1
- diffusers/models/transformer_temporal.py +1 -1
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/t5_film_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +29 -31
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unet_1d.py +1 -1
- diffusers/models/unet_1d_blocks.py +1 -1
- diffusers/models/unet_2d.py +1 -1
- diffusers/models/unet_2d_blocks.py +1 -1
- diffusers/models/unet_2d_condition.py +1 -1
- diffusers/models/unets/__init__.py +1 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +4 -4
- diffusers/models/unets/unet_2d_blocks.py +238 -98
- diffusers/models/unets/unet_2d_blocks_flax.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +420 -323
- diffusers/models/unets/unet_2d_condition_flax.py +21 -12
- diffusers/models/unets/unet_3d_blocks.py +50 -40
- diffusers/models/unets/unet_3d_condition.py +47 -8
- diffusers/models/unets/unet_i2vgen_xl.py +75 -30
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +48 -8
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +610 -0
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +10 -16
- diffusers/models/vae_flax.py +1 -1
- diffusers/models/vq_model.py +1 -1
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +26 -0
- diffusers/pipelines/amused/pipeline_amused.py +1 -1
- diffusers/pipelines/amused/pipeline_amused_img2img.py +1 -1
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff.py +162 -417
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +165 -137
- diffusers/pipelines/animatediff/pipeline_output.py +7 -6
- diffusers/pipelines/audioldm/pipeline_audioldm.py +3 -19
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +3 -3
- diffusers/pipelines/auto_pipeline.py +7 -16
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +90 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +98 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +92 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +145 -70
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +126 -89
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +108 -96
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -1
- diffusers/pipelines/ddim/pipeline_ddim.py +1 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +5 -5
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +5 -5
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +5 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +10 -120
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -91
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +1 -1
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +1 -1
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +5 -4
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +5 -4
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +7 -22
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -39
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +5 -5
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -22
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -2
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +1 -1
- diffusers/pipelines/free_init_utils.py +184 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +22 -104
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +2 -2
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +104 -93
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +112 -74
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ledits_pp/__init__.py +55 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +1505 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +1797 -0
- diffusers/pipelines/ledits_pp/pipeline_output.py +43 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +3 -19
- diffusers/pipelines/onnx_utils.py +1 -1
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +3 -3
- diffusers/pipelines/pia/pipeline_pia.py +168 -327
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +508 -0
- diffusers/pipelines/pipeline_utils.py +188 -534
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +56 -10
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +3 -3
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/__init__.py +50 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +482 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +311 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +638 -0
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +4 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +90 -146
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +4 -32
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -119
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -119
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +13 -59
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -31
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -33
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -21
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -21
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +5 -21
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +9 -38
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -34
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +6 -35
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +7 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +4 -124
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +282 -80
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +94 -46
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +3 -3
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +6 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +96 -148
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -154
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +98 -153
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +25 -87
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +89 -80
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +5 -49
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +80 -88
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +15 -86
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +20 -93
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +5 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +3 -19
- diffusers/pipelines/unclip/pipeline_unclip.py +1 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -1
- diffusers/pipelines/unclip/text_proj.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +35 -35
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +4 -21
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -5
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +2 -2
- diffusers/schedulers/__init__.py +7 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +1 -1
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +42 -19
- diffusers/schedulers/scheduling_ddim.py +2 -4
- diffusers/schedulers/scheduling_ddim_flax.py +13 -5
- diffusers/schedulers/scheduling_ddim_inverse.py +2 -4
- diffusers/schedulers/scheduling_ddim_parallel.py +2 -4
- diffusers/schedulers/scheduling_ddpm.py +2 -4
- diffusers/schedulers/scheduling_ddpm_flax.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -4
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +46 -19
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -21
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +9 -7
- diffusers/schedulers/scheduling_dpmsolver_sde.py +35 -35
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +52 -21
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +683 -0
- diffusers/schedulers/scheduling_edm_euler.py +381 -0
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +43 -15
- diffusers/schedulers/scheduling_euler_discrete.py +42 -17
- diffusers/schedulers/scheduling_euler_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_heun_discrete.py +35 -35
- diffusers/schedulers/scheduling_ipndm.py +37 -11
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +44 -44
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +44 -44
- diffusers/schedulers/scheduling_karras_ve_flax.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +38 -14
- diffusers/schedulers/scheduling_lms_discrete.py +43 -15
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +2 -4
- diffusers/schedulers/scheduling_pndm_flax.py +2 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +41 -9
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +1 -1
- diffusers/schedulers/scheduling_tcd.py +686 -0
- diffusers/schedulers/scheduling_unclip.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +46 -19
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +9 -2
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +1 -1
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +75 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +3 -3
- diffusers/utils/hub_utils.py +60 -16
- diffusers/utils/import_utils.py +15 -1
- diffusers/utils/loading_utils.py +2 -0
- diffusers/utils/logging.py +1 -1
- diffusers/utils/model_card_template.md +24 -0
- diffusers/utils/outputs.py +14 -7
- diffusers/utils/peft_utils.py +1 -1
- diffusers/utils/state_dict_utils.py +1 -1
- diffusers/utils/testing_utils.py +2 -0
- diffusers/utils/torch_utils.py +1 -1
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/METADATA +5 -5
- diffusers-0.27.0.dist-info/RECORD +399 -0
- diffusers-0.26.2.dist-info/RECORD +0 -384
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/LICENSE +0 -0
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/WHEEL +0 -0
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/top_level.txt +0 -0
```diff
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
-from ...utils import is_torch_version, logging
+from ...utils import deprecate, is_torch_version, logging
 from ...utils.torch_utils import apply_freeu
 from ..activations import get_activation
 from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
@@ -69,7 +69,7 @@ def get_down_block(
 ):
     # If attn head dim is not defined, we default it to the number of heads
     if attention_head_dim is None:
-        logger.
+        logger.warning(
             f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
         )
         attention_head_dim = num_attention_heads
@@ -249,6 +249,81 @@ def get_down_block(
     raise ValueError(f"{down_block_type} does not exist.")
 
 
+def get_mid_block(
+    mid_block_type: str,
+    temb_channels: int,
+    in_channels: int,
+    resnet_eps: float,
+    resnet_act_fn: str,
+    resnet_groups: int,
+    output_scale_factor: float = 1.0,
+    transformer_layers_per_block: int = 1,
+    num_attention_heads: Optional[int] = None,
+    cross_attention_dim: Optional[int] = None,
+    dual_cross_attention: bool = False,
+    use_linear_projection: bool = False,
+    mid_block_only_cross_attention: bool = False,
+    upcast_attention: bool = False,
+    resnet_time_scale_shift: str = "default",
+    attention_type: str = "default",
+    resnet_skip_time_act: bool = False,
+    cross_attention_norm: Optional[str] = None,
+    attention_head_dim: Optional[int] = 1,
+    dropout: float = 0.0,
+):
+    if mid_block_type == "UNetMidBlock2DCrossAttn":
+        return UNetMidBlock2DCrossAttn(
+            transformer_layers_per_block=transformer_layers_per_block,
+            in_channels=in_channels,
+            temb_channels=temb_channels,
+            dropout=dropout,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            output_scale_factor=output_scale_factor,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+            cross_attention_dim=cross_attention_dim,
+            num_attention_heads=num_attention_heads,
+            resnet_groups=resnet_groups,
+            dual_cross_attention=dual_cross_attention,
+            use_linear_projection=use_linear_projection,
+            upcast_attention=upcast_attention,
+            attention_type=attention_type,
+        )
+    elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
+        return UNetMidBlock2DSimpleCrossAttn(
+            in_channels=in_channels,
+            temb_channels=temb_channels,
+            dropout=dropout,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            output_scale_factor=output_scale_factor,
+            cross_attention_dim=cross_attention_dim,
+            attention_head_dim=attention_head_dim,
+            resnet_groups=resnet_groups,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+            skip_time_act=resnet_skip_time_act,
+            only_cross_attention=mid_block_only_cross_attention,
+            cross_attention_norm=cross_attention_norm,
+        )
+    elif mid_block_type == "UNetMidBlock2D":
+        return UNetMidBlock2D(
+            in_channels=in_channels,
+            temb_channels=temb_channels,
+            dropout=dropout,
+            num_layers=0,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            output_scale_factor=output_scale_factor,
+            resnet_groups=resnet_groups,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+            add_attention=False,
+        )
+    elif mid_block_type is None:
+        return None
+    else:
+        raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+
 def get_up_block(
     up_block_type: str,
     num_layers: int,
@@ -279,7 +354,7 @@ def get_up_block(
 ) -> nn.Module:
     # If attn head dim is not defined, we default it to the number of heads
     if attention_head_dim is None:
-        logger.
+        logger.warning(
             f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
         )
         attention_head_dim = num_attention_heads
@@ -598,7 +673,7 @@ class UNetMidBlock2D(nn.Module):
         attentions = []
 
         if attention_head_dim is None:
-            logger.
+            logger.warning(
                 f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
             )
             attention_head_dim = in_channels
@@ -769,8 +844,11 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ) -> torch.FloatTensor:
-
-
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+
+        hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if self.training and self.gradient_checkpointing:
 
@@ -807,7 +885,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
                     encoder_attention_mask=encoder_attention_mask,
                     return_dict=False,
                 )[0]
-                hidden_states = resnet(hidden_states, temb
+                hidden_states = resnet(hidden_states, temb)
 
         return hidden_states
 
@@ -907,7 +985,8 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ) -> torch.FloatTensor:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
-
+        if cross_attention_kwargs.get("scale", None) is not None:
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
 
         if attention_mask is None:
             # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
@@ -920,7 +999,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
             # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
             mask = attention_mask
 
-        hidden_states = self.resnets[0](hidden_states, temb
+        hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             # attn
             hidden_states = attn(
@@ -931,7 +1010,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
             )
 
             # resnet
-            hidden_states = resnet(hidden_states, temb
+            hidden_states = resnet(hidden_states, temb)
 
         return hidden_states
 
@@ -960,7 +1039,7 @@ class AttnDownBlock2D(nn.Module):
         self.downsample_type = downsample_type
 
         if attention_head_dim is None:
-            logger.
+            logger.warning(
                 f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
             )
             attention_head_dim = out_channels
@@ -1036,23 +1115,22 @@ class AttnDownBlock2D(nn.Module):
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
-
-
+        if cross_attention_kwargs.get("scale", None) is not None:
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
 
         output_states = ()
 
         for resnet, attn in zip(self.resnets, self.attentions):
-
-            hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+            hidden_states = resnet(hidden_states, temb)
             hidden_states = attn(hidden_states, **cross_attention_kwargs)
             output_states = output_states + (hidden_states,)
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
                 if self.downsample_type == "resnet":
-                    hidden_states = downsampler(hidden_states, temb=temb
+                    hidden_states = downsampler(hidden_states, temb=temb)
                 else:
-                    hidden_states = downsampler(hidden_states
+                    hidden_states = downsampler(hidden_states)
 
             output_states += (hidden_states,)
 
@@ -1161,9 +1239,11 @@ class CrossAttnDownBlock2D(nn.Module):
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         additional_residuals: Optional[torch.FloatTensor] = None,
     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
-
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
 
-
+        output_states = ()
 
         blocks = list(zip(self.resnets, self.attentions))
 
@@ -1195,7 +1275,7 @@ class CrossAttnDownBlock2D(nn.Module):
                     return_dict=False,
                 )[0]
             else:
-                hidden_states = resnet(hidden_states, temb
+                hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1213,7 +1293,7 @@ class CrossAttnDownBlock2D(nn.Module):
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states
+                hidden_states = downsampler(hidden_states)
 
             output_states = output_states + (hidden_states,)
 
@@ -1273,8 +1353,12 @@ class DownBlock2D(nn.Module):
         self.gradient_checkpointing = False
 
     def forward(
-        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None,
+        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
         output_states = ()
 
         for resnet in self.resnets:
@@ -1295,13 +1379,13 @@ class DownBlock2D(nn.Module):
                     create_custom_forward(resnet), hidden_states, temb
                 )
             else:
-                hidden_states = resnet(hidden_states, temb
+                hidden_states = resnet(hidden_states, temb)
 
             output_states = output_states + (hidden_states,)
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states
+                hidden_states = downsampler(hidden_states)
 
             output_states = output_states + (hidden_states,)
 
@@ -1372,13 +1456,17 @@ class DownEncoderBlock2D(nn.Module):
         else:
             self.downsamplers = None
 
-    def forward(self, hidden_states: torch.FloatTensor,
+    def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
         for resnet in self.resnets:
-            hidden_states = resnet(hidden_states, temb=None
+            hidden_states = resnet(hidden_states, temb=None)
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states
+                hidden_states = downsampler(hidden_states)
 
         return hidden_states
 
@@ -1405,7 +1493,7 @@ class AttnDownEncoderBlock2D(nn.Module):
         attentions = []
 
         if attention_head_dim is None:
-            logger.
+            logger.warning(
                 f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
             )
             attention_head_dim = out_channels
@@ -1470,15 +1558,18 @@ class AttnDownEncoderBlock2D(nn.Module):
         else:
             self.downsamplers = None
 
-    def forward(self, hidden_states: torch.FloatTensor,
+    def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
         for resnet, attn in zip(self.resnets, self.attentions):
-            hidden_states = resnet(hidden_states, temb=None
-
-            hidden_states = attn(hidden_states, **cross_attention_kwargs)
+            hidden_states = resnet(hidden_states, temb=None)
+            hidden_states = attn(hidden_states)
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states
+                hidden_states = downsampler(hidden_states)
 
         return hidden_states
 
@@ -1504,7 +1595,7 @@ class AttnSkipDownBlock2D(nn.Module):
         self.resnets = nn.ModuleList([])
 
         if attention_head_dim is None:
-            logger.
+            logger.warning(
                 f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
            )
            attention_head_dim = out_channels
@@ -1569,18 +1660,22 @@ class AttnSkipDownBlock2D(nn.Module):
         hidden_states: torch.FloatTensor,
         temb: Optional[torch.FloatTensor] = None,
         skip_sample: Optional[torch.FloatTensor] = None,
-
+        *args,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
         output_states = ()
 
         for resnet, attn in zip(self.resnets, self.attentions):
-            hidden_states = resnet(hidden_states, temb
-
-            hidden_states = attn(hidden_states, **cross_attention_kwargs)
+            hidden_states = resnet(hidden_states, temb)
+            hidden_states = attn(hidden_states)
             output_states += (hidden_states,)
 
         if self.downsamplers is not None:
-            hidden_states = self.resnet_down(hidden_states, temb
+            hidden_states = self.resnet_down(hidden_states, temb)
             for downsampler in self.downsamplers:
                 skip_sample = downsampler(skip_sample)
 
@@ -1656,16 +1751,21 @@ class SkipDownBlock2D(nn.Module):
         hidden_states: torch.FloatTensor,
         temb: Optional[torch.FloatTensor] = None,
         skip_sample: Optional[torch.FloatTensor] = None,
-
+        *args,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
         output_states = ()
 
         for resnet in self.resnets:
-            hidden_states = resnet(hidden_states, temb
+            hidden_states = resnet(hidden_states, temb)
             output_states += (hidden_states,)
 
         if self.downsamplers is not None:
-            hidden_states = self.resnet_down(hidden_states, temb
+            hidden_states = self.resnet_down(hidden_states, temb)
             for downsampler in self.downsamplers:
                 skip_sample = downsampler(skip_sample)
 
@@ -1741,8 +1841,12 @@ class ResnetDownsampleBlock2D(nn.Module):
         self.gradient_checkpointing = False
 
     def forward(
-        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None,
+        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
         output_states = ()
 
         for resnet in self.resnets:
@@ -1763,13 +1867,13 @@ class ResnetDownsampleBlock2D(nn.Module):
                     create_custom_forward(resnet), hidden_states, temb
                 )
             else:
-                hidden_states = resnet(hidden_states, temb
+                hidden_states = resnet(hidden_states, temb)
 
             output_states = output_states + (hidden_states,)
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states, temb
+                hidden_states = downsampler(hidden_states, temb)
 
             output_states = output_states + (hidden_states,)
 
@@ -1880,10 +1984,11 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
-        output_states = ()
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+        if cross_attention_kwargs.get("scale", None) is not None:
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
 
-
+        output_states = ()
 
         if attention_mask is None:
             # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
@@ -1916,7 +2021,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
                     **cross_attention_kwargs,
                 )
             else:
-                hidden_states = resnet(hidden_states, temb
+                hidden_states = resnet(hidden_states, temb)
 
                 hidden_states = attn(
                     hidden_states,
@@ -1929,7 +2034,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
 
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states, temb
+                hidden_states = downsampler(hidden_states, temb)
 
             output_states = output_states + (hidden_states,)
 
@@ -1983,8 +2088,12 @@ class KDownBlock2D(nn.Module):
         self.gradient_checkpointing = False
 
     def forward(
-        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None,
+        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
         output_states = ()
 
         for resnet in self.resnets:
@@ -2005,7 +2114,7 @@ class KDownBlock2D(nn.Module):
                     create_custom_forward(resnet), hidden_states, temb
                 )
             else:
-                hidden_states = resnet(hidden_states, temb
+                hidden_states = resnet(hidden_states, temb)
 
             output_states += (hidden_states,)
 
@@ -2090,8 +2199,11 @@ class KCrossAttnDownBlock2D(nn.Module):
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+        if cross_attention_kwargs.get("scale", None) is not None:
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+
         output_states = ()
-        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
 
         for resnet, attn in zip(self.resnets, self.attentions):
             if self.training and self.gradient_checkpointing:
@@ -2121,7 +2233,7 @@ class KCrossAttnDownBlock2D(nn.Module):
                     encoder_attention_mask=encoder_attention_mask,
                 )
             else:
-                hidden_states = resnet(hidden_states, temb
+                hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -2169,7 +2281,7 @@ class AttnUpBlock2D(nn.Module):
         self.upsample_type = upsample_type
 
         if attention_head_dim is None:
-            logger.
+            logger.warning(
                 f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
             )
             attention_head_dim = out_channels
@@ -2241,24 +2353,28 @@ class AttnUpBlock2D(nn.Module):
         res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
         temb: Optional[torch.FloatTensor] = None,
         upsample_size: Optional[int] = None,
-
+        *args,
+        **kwargs,
     ) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
         for resnet, attn in zip(self.resnets, self.attentions):
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
-            hidden_states = resnet(hidden_states, temb
-
-            hidden_states = attn(hidden_states, **cross_attention_kwargs)
+            hidden_states = resnet(hidden_states, temb)
+            hidden_states = attn(hidden_states)
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
                 if self.upsample_type == "resnet":
-                    hidden_states = upsampler(hidden_states, temb=temb
+                    hidden_states = upsampler(hidden_states, temb=temb)
                 else:
-                    hidden_states = upsampler(hidden_states
+                    hidden_states = upsampler(hidden_states)
 
         return hidden_states
 
@@ -2365,7 +2481,10 @@ class CrossAttnUpBlock2D(nn.Module):
         attention_mask: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ) -> torch.FloatTensor:
-
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+
         is_freeu_enabled = (
             getattr(self, "s1", None)
             and getattr(self, "s2", None)
@@ -2419,7 +2538,7 @@ class CrossAttnUpBlock2D(nn.Module):
                     return_dict=False,
                 )[0]
             else:
-                hidden_states = resnet(hidden_states, temb
+                hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -2431,7 +2550,7 @@ class CrossAttnUpBlock2D(nn.Module):
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, upsample_size
+                hidden_states = upsampler(hidden_states, upsample_size)
 
         return hidden_states
 
@@ -2492,8 +2611,13 @@ class UpBlock2D(nn.Module):
         res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
         temb: Optional[torch.FloatTensor] = None,
         upsample_size: Optional[int] = None,
-
+        *args,
+        **kwargs,
     ) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
         is_freeu_enabled = (
             getattr(self, "s1", None)
             and getattr(self, "s2", None)
@@ -2537,11 +2661,11 @@ class UpBlock2D(nn.Module):
                     create_custom_forward(resnet), hidden_states, temb
                 )
             else:
-                hidden_states = resnet(hidden_states, temb
+                hidden_states = resnet(hidden_states, temb)
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, upsample_size
+                hidden_states = upsampler(hidden_states, upsample_size)
 
         return hidden_states
 
@@ -2608,11 +2732,9 @@ class UpDecoderBlock2D(nn.Module):
 
         self.resolution_idx = resolution_idx
 
-    def forward(
-        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
-    ) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
         for resnet in self.resnets:
-            hidden_states = resnet(hidden_states, temb=temb
+            hidden_states = resnet(hidden_states, temb=temb)
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
@@ -2644,7 +2766,7 @@ class AttnUpDecoderBlock2D(nn.Module):
         attentions = []
 
         if attention_head_dim is None:
-            logger.
+            logger.warning(
                 f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
             )
             attention_head_dim = out_channels
@@ -2708,17 +2830,14 @@ class AttnUpDecoderBlock2D(nn.Module):
 
         self.resolution_idx = resolution_idx
 
-    def forward(
-        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
-    ) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
         for resnet, attn in zip(self.resnets, self.attentions):
-            hidden_states = resnet(hidden_states, temb=temb
-
-            hidden_states = attn(hidden_states, temb=temb, **cross_attention_kwargs)
+            hidden_states = resnet(hidden_states, temb=temb)
+            hidden_states = attn(hidden_states, temb=temb)
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states
+                hidden_states = upsampler(hidden_states)
 
         return hidden_states
 
@@ -2766,7 +2885,7 @@ class AttnSkipUpBlock2D(nn.Module):
         )
 
         if attention_head_dim is None:
-            logger.
+            logger.warning(
                 f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
             )
             attention_head_dim = out_channels
@@ -2823,18 +2942,22 @@ class AttnSkipUpBlock2D(nn.Module):
         res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
         temb: Optional[torch.FloatTensor] = None,
         skip_sample=None,
-
+        *args,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
         for resnet in self.resnets:
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
-            hidden_states = resnet(hidden_states, temb
+            hidden_states = resnet(hidden_states, temb)
 
-
-        hidden_states = self.attentions[0](hidden_states, **cross_attention_kwargs)
+        hidden_states = self.attentions[0](hidden_states)
 
         if skip_sample is not None:
             skip_sample = self.upsampler(skip_sample)
@@ -2848,7 +2971,7 @@ class AttnSkipUpBlock2D(nn.Module):
 
             skip_sample = skip_sample + skip_sample_states
 
-            hidden_states = self.resnet_up(hidden_states, temb
+            hidden_states = self.resnet_up(hidden_states, temb)
 
         return hidden_states, skip_sample
 
@@ -2931,15 +3054,20 @@ class SkipUpBlock2D(nn.Module):
         res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
         temb: Optional[torch.FloatTensor] = None,
         skip_sample=None,
-
+        *args,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
         for resnet in self.resnets:
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
-            hidden_states = resnet(hidden_states, temb
+            hidden_states = resnet(hidden_states, temb)
 
         if skip_sample is not None:
             skip_sample = self.upsampler(skip_sample)
@@ -2953,7 +3081,7 @@ class SkipUpBlock2D(nn.Module):
 
             skip_sample = skip_sample + skip_sample_states
 
-            hidden_states = self.resnet_up(hidden_states, temb
+            hidden_states = self.resnet_up(hidden_states, temb)
 
         return hidden_states, skip_sample
 
@@ -3033,8 +3161,13 @@ class ResnetUpsampleBlock2D(nn.Module):
         res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
         temb: Optional[torch.FloatTensor] = None,
         upsample_size: Optional[int] = None,
-
+        *args,
+        **kwargs,
     ) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
         for resnet in self.resnets:
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
@@ -3058,11 +3191,11 @@ class ResnetUpsampleBlock2D(nn.Module):
                     create_custom_forward(resnet), hidden_states, temb
                 )
             else:
-                hidden_states = resnet(hidden_states, temb
+                hidden_states = resnet(hidden_states, temb)
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, temb
+                hidden_states = upsampler(hidden_states, temb)
 
         return hidden_states
 
@@ -3178,8 +3311,9 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ) -> torch.FloatTensor:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+        if cross_attention_kwargs.get("scale", None) is not None:
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
 
-        lora_scale = cross_attention_kwargs.get("scale", 1.0)
         if attention_mask is None:
             # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
             mask = None if encoder_hidden_states is None else encoder_attention_mask
@@ -3217,7 +3351,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
                     **cross_attention_kwargs,
                 )
             else:
-                hidden_states = resnet(hidden_states, temb
+                hidden_states = resnet(hidden_states, temb)
 
                 hidden_states = attn(
                     hidden_states,
@@ -3228,7 +3362,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, temb
+                hidden_states = upsampler(hidden_states, temb)
 
         return hidden_states
 
@@ -3289,8 +3423,13 @@ class KUpBlock2D(nn.Module):
         res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
         temb: Optional[torch.FloatTensor] = None,
         upsample_size: Optional[int] = None,
-
+        *args,
+        **kwargs,
     ) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
         res_hidden_states_tuple = res_hidden_states_tuple[-1]
         if res_hidden_states_tuple is not None:
             hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
@@ -3313,7 +3452,7 @@ class KUpBlock2D(nn.Module):
                     create_custom_forward(resnet), hidden_states, temb
                 )
             else:
-                hidden_states = resnet(hidden_states, temb
+                hidden_states = resnet(hidden_states, temb)
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
@@ -3423,7 +3562,6 @@ class KCrossAttnUpBlock2D(nn.Module):
         if res_hidden_states_tuple is not None:
             hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
 
-        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
         for resnet, attn in zip(self.resnets, self.attentions):
             if self.training and self.gradient_checkpointing:
 
@@ -3452,7 +3590,7 @@ class KCrossAttnUpBlock2D(nn.Module):
                     encoder_attention_mask=encoder_attention_mask,
                 )
             else:
-                hidden_states = resnet(hidden_states, temb
+                hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -3555,6 +3693,8 @@ class KAttentionBlock(nn.Module):
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ) -> torch.FloatTensor:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+        if cross_attention_kwargs.get("scale", None) is not None:
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
 
         # 1. Self-Attention
         if self.add_self_attention: