diffusers 0.26.3__py3-none-any.whl → 0.27.0__py3-none-any.whl
- diffusers/__init__.py +20 -1
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/diffusers_cli.py +1 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +7 -3
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +2 -2
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/image_processor.py +110 -4
- diffusers/loaders/autoencoder.py +7 -8
- diffusers/loaders/controlnet.py +17 -8
- diffusers/loaders/ip_adapter.py +86 -23
- diffusers/loaders/lora.py +105 -310
- diffusers/loaders/lora_conversion_utils.py +1 -1
- diffusers/loaders/peft.py +1 -1
- diffusers/loaders/single_file.py +51 -12
- diffusers/loaders/single_file_utils.py +274 -49
- diffusers/loaders/textual_inversion.py +23 -4
- diffusers/loaders/unet.py +195 -41
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +3 -1
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +26 -36
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +171 -114
- diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +1 -1
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flax.py +1 -1
- diffusers/models/downsampling.py +8 -12
- diffusers/models/dual_transformer_2d.py +1 -1
- diffusers/models/embeddings.py +3 -4
- diffusers/models/embeddings_flax.py +1 -1
- diffusers/models/lora.py +33 -10
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +4 -6
- diffusers/models/normalization.py +1 -1
- diffusers/models/resnet.py +31 -58
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/t5_film_transformer.py +1 -1
- diffusers/models/transformer_2d.py +1 -1
- diffusers/models/transformer_temporal.py +1 -1
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/t5_film_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +29 -31
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unet_1d.py +1 -1
- diffusers/models/unet_1d_blocks.py +1 -1
- diffusers/models/unet_2d.py +1 -1
- diffusers/models/unet_2d_blocks.py +1 -1
- diffusers/models/unet_2d_condition.py +1 -1
- diffusers/models/unets/__init__.py +1 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +4 -4
- diffusers/models/unets/unet_2d_blocks.py +238 -98
- diffusers/models/unets/unet_2d_blocks_flax.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +420 -323
- diffusers/models/unets/unet_2d_condition_flax.py +21 -12
- diffusers/models/unets/unet_3d_blocks.py +50 -40
- diffusers/models/unets/unet_3d_condition.py +47 -8
- diffusers/models/unets/unet_i2vgen_xl.py +75 -30
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +48 -8
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +610 -0
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +10 -16
- diffusers/models/vae_flax.py +1 -1
- diffusers/models/vq_model.py +1 -1
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +26 -0
- diffusers/pipelines/amused/pipeline_amused.py +1 -1
- diffusers/pipelines/amused/pipeline_amused_img2img.py +1 -1
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff.py +162 -417
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +165 -137
- diffusers/pipelines/animatediff/pipeline_output.py +7 -6
- diffusers/pipelines/audioldm/pipeline_audioldm.py +3 -19
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +3 -3
- diffusers/pipelines/auto_pipeline.py +7 -16
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +90 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +98 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +92 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +145 -70
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +126 -89
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +108 -96
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -1
- diffusers/pipelines/ddim/pipeline_ddim.py +1 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +5 -5
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +5 -5
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +5 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +10 -120
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -91
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +1 -1
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +1 -1
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +5 -4
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +5 -4
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +7 -22
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -39
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +5 -5
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -22
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -2
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +1 -1
- diffusers/pipelines/free_init_utils.py +184 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +22 -104
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +2 -2
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +104 -93
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +112 -74
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ledits_pp/__init__.py +55 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +1505 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +1797 -0
- diffusers/pipelines/ledits_pp/pipeline_output.py +43 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +3 -19
- diffusers/pipelines/onnx_utils.py +1 -1
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +3 -3
- diffusers/pipelines/pia/pipeline_pia.py +168 -327
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +508 -0
- diffusers/pipelines/pipeline_utils.py +188 -534
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +56 -10
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +3 -3
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/__init__.py +50 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +482 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +311 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +638 -0
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +4 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +90 -146
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +4 -32
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -119
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -119
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +13 -59
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -31
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -33
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -21
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -21
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +5 -21
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +9 -38
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -34
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +6 -35
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +7 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +4 -124
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +282 -80
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +94 -46
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +3 -3
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +6 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +96 -148
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -154
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +98 -153
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +25 -87
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +89 -80
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +5 -49
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +80 -88
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +15 -86
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +20 -93
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +5 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +3 -19
- diffusers/pipelines/unclip/pipeline_unclip.py +1 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -1
- diffusers/pipelines/unclip/text_proj.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +35 -35
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +4 -21
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -5
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +2 -2
- diffusers/schedulers/__init__.py +7 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +1 -1
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +42 -19
- diffusers/schedulers/scheduling_ddim.py +2 -4
- diffusers/schedulers/scheduling_ddim_flax.py +13 -5
- diffusers/schedulers/scheduling_ddim_inverse.py +2 -4
- diffusers/schedulers/scheduling_ddim_parallel.py +2 -4
- diffusers/schedulers/scheduling_ddpm.py +2 -4
- diffusers/schedulers/scheduling_ddpm_flax.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -4
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +46 -19
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -21
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +9 -7
- diffusers/schedulers/scheduling_dpmsolver_sde.py +35 -35
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +49 -18
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +683 -0
- diffusers/schedulers/scheduling_edm_euler.py +381 -0
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +43 -15
- diffusers/schedulers/scheduling_euler_discrete.py +42 -17
- diffusers/schedulers/scheduling_euler_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_heun_discrete.py +35 -35
- diffusers/schedulers/scheduling_ipndm.py +37 -11
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +44 -44
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +44 -44
- diffusers/schedulers/scheduling_karras_ve_flax.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +38 -14
- diffusers/schedulers/scheduling_lms_discrete.py +43 -15
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +2 -4
- diffusers/schedulers/scheduling_pndm_flax.py +2 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +41 -9
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +1 -1
- diffusers/schedulers/scheduling_tcd.py +686 -0
- diffusers/schedulers/scheduling_unclip.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +46 -19
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +9 -2
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +1 -1
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +75 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +3 -3
- diffusers/utils/hub_utils.py +60 -16
- diffusers/utils/import_utils.py +15 -1
- diffusers/utils/loading_utils.py +2 -0
- diffusers/utils/logging.py +1 -1
- diffusers/utils/model_card_template.md +24 -0
- diffusers/utils/outputs.py +14 -7
- diffusers/utils/peft_utils.py +1 -1
- diffusers/utils/state_dict_utils.py +1 -1
- diffusers/utils/testing_utils.py +2 -0
- diffusers/utils/torch_utils.py +1 -1
- {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/METADATA +46 -46
- diffusers-0.27.0.dist-info/RECORD +399 -0
- {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/WHEEL +1 -1
- diffusers-0.26.3.dist-info/RECORD +0 -384
- {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/LICENSE +0 -0
- {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/top_level.txt +0 -0
diffusers/loaders/unet.py
CHANGED
```diff
@@ -1,4 +1,4 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -37,10 +37,16 @@ from ..utils import (
     _get_model_file,
     delete_adapter_layers,
     is_accelerate_available,
+    is_torch_version,
     logging,
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from .single_file_utils import (
+    convert_stable_cascade_unet_single_file_to_diffusers,
+    infer_stable_cascade_single_file_config,
+    load_single_file_model_checkpoint,
+)
 from .utils import AttnProcsLayers
```
```diff
@@ -168,15 +174,6 @@ class UNet2DConditionLoadersMixin:
             "framework": "pytorch",
         }
 
-        if low_cpu_mem_usage and not is_accelerate_available():
-            low_cpu_mem_usage = False
-            logger.warning(
-                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
-                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
-                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
-                " install accelerate\n```\n."
-            )
-
         model_file = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
             # Let's first try to load .safetensors weights
@@ -353,7 +350,7 @@ class UNet2DConditionLoadersMixin:
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
 
-        # For PEFT backend the Unet is already offloaded at this stage as it is handled inside `
+        # For PEFT backend the Unet is already offloaded at this stage as it is handled inside `load_lora_weights_into_unet`
         if not USE_PEFT_BACKEND:
             if _pipeline is not None:
                 for _, component in _pipeline.components.items():
@@ -392,7 +389,7 @@ class UNet2DConditionLoadersMixin:
         is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
         if is_text_encoder_present:
             warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
-            logger.warn(warn_message)
+            logger.warning(warn_message)
         unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
         state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
```
```diff
@@ -694,9 +691,29 @@ class UNet2DConditionLoadersMixin:
         if hasattr(self, "peft_config"):
             self.peft_config.pop(adapter_name, None)
 
-    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
+    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
+        if low_cpu_mem_usage:
+            if is_accelerate_available():
+                from accelerate import init_empty_weights
+
+            else:
+                low_cpu_mem_usage = False
+                logger.warning(
+                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                    " install accelerate\n```\n."
+                )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
         updated_state_dict = {}
         image_projection = None
+        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
 
         if "proj.weight" in state_dict:
             # IP-Adapter
@@ -704,11 +721,12 @@ class UNet2DConditionLoadersMixin:
             clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
             cross_attention_dim = state_dict["proj.weight"].shape[0] // 4
 
-            image_projection = ImageProjection(
-                cross_attention_dim=cross_attention_dim,
-                image_embed_dim=clip_embeddings_dim,
-                num_image_text_embeds=num_image_text_embeds,
-            )
+            with init_context():
+                image_projection = ImageProjection(
+                    cross_attention_dim=cross_attention_dim,
+                    image_embed_dim=clip_embeddings_dim,
+                    num_image_text_embeds=num_image_text_embeds,
+                )
 
             for key, value in state_dict.items():
                 diffusers_name = key.replace("proj", "image_embeds")
@@ -719,9 +737,10 @@ class UNet2DConditionLoadersMixin:
             clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
             cross_attention_dim = state_dict["proj.3.weight"].shape[0]
 
-            image_projection = IPAdapterFullImageProjection(
-                cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
-            )
+            with init_context():
+                image_projection = IPAdapterFullImageProjection(
+                    cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
+                )
 
             for key, value in state_dict.items():
                 diffusers_name = key.replace("proj.0", "ff.net.0.proj")
@@ -737,13 +756,14 @@ class UNet2DConditionLoadersMixin:
             hidden_dims = state_dict["latents"].shape[2]
             heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
 
-            image_projection = IPAdapterPlusImageProjection(
-                embed_dims=embed_dims,
-                output_dims=output_dims,
-                hidden_dims=hidden_dims,
-                heads=heads,
-                num_queries=num_image_text_embeds,
-            )
+            with init_context():
+                image_projection = IPAdapterPlusImageProjection(
+                    embed_dims=embed_dims,
+                    output_dims=output_dims,
+                    hidden_dims=hidden_dims,
+                    heads=heads,
+                    num_queries=num_image_text_embeds,
+                )
 
             for key, value in state_dict.items():
                 diffusers_name = key.replace("0.to", "2.to")
@@ -765,10 +785,14 @@ class UNet2DConditionLoadersMixin:
             else:
                 updated_state_dict[diffusers_name] = value
 
-        image_projection.load_state_dict(updated_state_dict)
+        if not low_cpu_mem_usage:
+            image_projection.load_state_dict(updated_state_dict)
+        else:
+            load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
+
         return image_projection
 
-    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts):
+    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
         from ..models.attention_processor import (
             AttnProcessor,
             AttnProcessor2_0,
```
```diff
@@ -776,9 +800,29 @@ class UNet2DConditionLoadersMixin:
             IPAdapterAttnProcessor2_0,
         )
 
+        if low_cpu_mem_usage:
+            if is_accelerate_available():
+                from accelerate import init_empty_weights
+
+            else:
+                low_cpu_mem_usage = False
+                logger.warning(
+                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                    " install accelerate\n```\n."
+                )
+
+        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+            raise NotImplementedError(
+                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                " `low_cpu_mem_usage=False`."
+            )
+
         # set ip-adapter cross-attention processors & load state_dict
         attn_procs = {}
         key_id = 1
+        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
         for name in self.attn_processors.keys():
             cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
             if name.startswith("mid_block"):
```
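The practical effect of this plumbing is that IP-Adapter modules can be built on the `meta` device and filled directly from the checkpoint, instead of being initialized with random weights first. A minimal sketch of how the flag is reached from a pipeline, assuming `accelerate` is installed; the model IDs are illustrative and the exact kwarg exposure follows the `diffusers/loaders/ip_adapter.py` changes listed above:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# `low_cpu_mem_usage` is forwarded to `unet._load_ip_adapter_weights`, which passes it
# to the two converter methods patched above; with accelerate installed they build the
# modules under init_empty_weights() and fill them via load_model_dict_into_meta.
pipe.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="models",
    weight_name="ip-adapter_sd15.bin",
    low_cpu_mem_usage=True,
)
```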
```diff
@@ -811,39 +855,149 @@ class UNet2DConditionLoadersMixin:
                     # IP-Adapter Plus
                     num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
 
-                attn_procs[name] = attn_processor_class(
-                    hidden_size=hidden_size,
-                    cross_attention_dim=cross_attention_dim,
-                    scale=1.0,
-                    num_tokens=num_image_text_embeds,
-                ).to(dtype=self.dtype, device=self.device)
+                with init_context():
+                    attn_procs[name] = attn_processor_class(
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                        scale=1.0,
+                        num_tokens=num_image_text_embeds,
+                    )
 
                 value_dict = {}
                 for i, state_dict in enumerate(state_dicts):
                     value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
                     value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
 
-                attn_procs[name].load_state_dict(value_dict)
+                if not low_cpu_mem_usage:
+                    attn_procs[name].load_state_dict(value_dict)
+                else:
+                    device = next(iter(value_dict.values())).device
+                    dtype = next(iter(value_dict.values())).dtype
+                    load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
+
                 key_id += 2
 
         return attn_procs
 
-    def _load_ip_adapter_weights(self, state_dicts):
+    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
         if not isinstance(state_dicts, list):
             state_dicts = [state_dicts]
         # Set encoder_hid_proj after loading ip_adapter weights,
         # because `IPAdapterPlusImageProjection` also has `attn_processors`.
         self.encoder_hid_proj = None
 
-        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts)
+        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
         self.set_attn_processor(attn_procs)
 
         # convert IP-Adapter Image Projection layers to diffusers
         image_projection_layers = []
         for state_dict in state_dicts:
-            image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
-            image_projection_layer.to(device=self.device, dtype=self.dtype)
+            image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
+                state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
+            )
             image_projection_layers.append(image_projection_layer)
 
         self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
         self.config.encoder_hid_dim_type = "ip_image_proj"
+
+        self.to(dtype=self.dtype, device=self.device)
+
+
+class FromOriginalUNetMixin:
+    """
+    Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`StableCascadeUNet`].
+    """
+
+    @classmethod
+    @validate_hf_hub_args
+    def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
+        r"""
+        Instantiate a [`StableCascadeUNet`] from pretrained StableCascadeUNet weights saved in the original `.ckpt` or
+        `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
+
+        Parameters:
+            pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+                    - A link to the `.ckpt` file (for example
+                      `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
+                    - A path to a *file* containing all pipeline weights.
+            config: (`dict`, *optional*):
+                Dictionary containing the configuration of the model:
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
+                dtype is automatically derived from the model's weights.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+                incompletely downloaded files are deleted.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to True, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to overwrite load and saveable variables of the model.
+
+        """
+        class_name = cls.__name__
+        if class_name != "StableCascadeUNet":
+            raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet")
+
+        config = kwargs.pop("config", None)
+        resume_download = kwargs.pop("resume_download", False)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        token = kwargs.pop("token", None)
+        cache_dir = kwargs.pop("cache_dir", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        revision = kwargs.pop("revision", None)
+        torch_dtype = kwargs.pop("torch_dtype", None)
+
+        checkpoint = load_single_file_model_checkpoint(
+            pretrained_model_link_or_path,
+            resume_download=resume_download,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            cache_dir=cache_dir,
+            local_files_only=local_files_only,
+            revision=revision,
+        )
+
+        if config is None:
+            config = infer_stable_cascade_single_file_config(checkpoint)
+            model_config = cls.load_config(**config, **kwargs)
+        else:
+            model_config = config
+
+        ctx = init_empty_weights if is_accelerate_available() else nullcontext
+        with ctx():
+            model = cls.from_config(model_config, **kwargs)
+
+        diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint)
+        if is_accelerate_available():
+            unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+            if len(unexpected_keys) > 0:
+                logger.warn(
+                    f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
+                )
+
+        else:
+            model.load_state_dict(diffusers_format_checkpoint)
+
+        if torch_dtype is not None:
+            model.to(torch_dtype)
+
+        return model
```
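The new `FromOriginalUNetMixin` is the loader half of the Stable Cascade support added in this release (see `diffusers/models/unets/unet_stable_cascade.py`, +610 lines, in the file list). A rough sketch of the entry point it enables, reusing the placeholder URL from the docstring above; substitute a real Stable Cascade checkpoint:

```python
import torch
from diffusers import StableCascadeUNet

# The URL is the docstring's placeholder, not a real checkpoint path.
unet = StableCascadeUNet.from_single_file(
    "https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors",
    torch_dtype=torch.bfloat16,
)

# When `config` is not passed, the config is inferred from the checkpoint via
# infer_stable_cascade_single_file_config; with accelerate installed, the model is
# built under init_empty_weights() and the state dict, converted by
# convert_stable_cascade_unet_single_file_to_diffusers, is loaded onto the meta
# tensors, so no random weights are ever materialized.
```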
diffusers/loaders/utils.py
CHANGED
diffusers/models/__init__.py
CHANGED
```diff
@@ -1,4 +1,4 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -47,6 +47,7 @@ if is_torch_available():
     _import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"]
     _import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
     _import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
+    _import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"]
     _import_structure["unets.uvit_2d"] = ["UVit2DModel"]
     _import_structure["vq_model"] = ["VQModel"]
@@ -80,6 +81,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         I2VGenXLUNet,
         Kandinsky3UNet,
         MotionAdapter,
+        StableCascadeUNet,
         UNet1DModel,
         UNet2DConditionModel,
         UNet2DModel,
```
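With the lazy-import entry and the `TYPE_CHECKING` import both registered, the new UNet is importable like any other model class. Whether it is also re-exported at the top level is governed by the `diffusers/__init__.py` change (+20/-1) in the file list; a sketch assuming it is:

```python
# Both routes resolve through the _import_structure machinery patched above.
from diffusers.models import StableCascadeUNet

# Assumed top-level re-export, per the diffusers/__init__.py entry in the file list.
from diffusers import StableCascadeUNet as TopLevelStableCascadeUNet

assert StableCascadeUNet is TopLevelStableCascadeUNet
```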
diffusers/models/activations.py
CHANGED
```diff
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2023 HuggingFace Inc.
+# Copyright 2024 HuggingFace Inc.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,8 +17,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
-from ..utils import USE_PEFT_BACKEND
-from .lora import LoRACompatibleLinear
+from ..utils import deprecate
 
 
 ACTIVATION_FUNCTIONS = {
@@ -87,9 +86,7 @@ class GEGLU(nn.Module):
 
     def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
         super().__init__()
-        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
-
-        self.proj = linear_cls(dim_in, dim_out * 2, bias=bias)
+        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
 
     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
         if gate.device.type != "mps":
@@ -97,9 +94,12 @@ class GEGLU(nn.Module):
             # mps: gelu is not implemented for float16
             return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
 
-    def forward(self, hidden_states, scale: float = 1.0):
-        args = () if USE_PEFT_BACKEND else (scale,)
-        hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
+    def forward(self, hidden_states, *args, **kwargs):
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
+        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
         return hidden_states * self.gelu(gate)
```
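The net effect is backward compatibility: call sites that still pass `scale` (previously consumed by `LoRACompatibleLinear`) keep working, but the value is ignored and a deprecation warning fires. A small self-contained sketch:

```python
import torch
from diffusers.models.activations import GEGLU

geglu = GEGLU(dim_in=32, dim_out=64)
x = torch.randn(1, 16, 32)

out = geglu(x)             # normal path, no warning
out = geglu(x, scale=0.5)  # still accepted, but `scale` is ignored and
                           # deprecate("scale", "1.0.0", ...) emits a FutureWarning
print(out.shape)           # torch.Size([1, 16, 64])
```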
diffusers/models/attention.py
CHANGED
```diff
@@ -1,4 +1,4 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
+# Copyright 2024 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,18 +17,18 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
-from ..utils import USE_PEFT_BACKEND
+from ..utils import deprecate, logging
 from ..utils.torch_utils import maybe_allow_in_graph
 from .activations import GEGLU, GELU, ApproximateGELU
 from .attention_processor import Attention
 from .embeddings import SinusoidalPositionalEmbedding
-from .lora import LoRACompatibleLinear
 from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
 
 
-def _chunked_feed_forward(
-    ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
-):
+logger = logging.get_logger(__name__)
+
+
+def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
     # "feed_forward_chunk_size" can be used to save memory
     if hidden_states.shape[chunk_dim] % chunk_size != 0:
         raise ValueError(
@@ -36,18 +36,10 @@ def _chunked_feed_forward(
     )
 
     num_chunks = hidden_states.shape[chunk_dim] // chunk_size
-    if lora_scale is None:
-        ff_output = torch.cat(
-            [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
-            dim=chunk_dim,
-        )
-    else:
-        # TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete
-        ff_output = torch.cat(
-            [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
-            dim=chunk_dim,
-        )
-
+    ff_output = torch.cat(
+        [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
+        dim=chunk_dim,
+    )
     return ff_output
 
 
@@ -143,7 +135,7 @@ class BasicTransformerBlock(nn.Module):
         double_self_attention: bool = False,
         upcast_attention: bool = False,
         norm_elementwise_affine: bool = True,
-        norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'layer_norm_i2vgen'
+        norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
         norm_eps: float = 1e-5,
         final_dropout: bool = False,
         attention_type: str = "default",
@@ -158,6 +150,7 @@ class BasicTransformerBlock(nn.Module):
         super().__init__()
         self.only_cross_attention = only_cross_attention
 
+        # We keep these boolean flags for backward-compatibility.
         self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
         self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
         self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
@@ -298,6 +291,10 @@ class BasicTransformerBlock(nn.Module):
         class_labels: Optional[torch.LongTensor] = None,
         added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.FloatTensor:
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+
         # Notice that normalization is always applied before the real computation in the following blocks.
         # 0. Self-Attention
         batch_size = hidden_states.shape[0]
@@ -325,10 +322,7 @@ class BasicTransformerBlock(nn.Module):
         if self.pos_embed is not None:
             norm_hidden_states = self.pos_embed(norm_hidden_states)
 
-        # 1. Retrieve lora scale.
-        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
-
-        # 2. Prepare GLIGEN inputs
+        # 1. Prepare GLIGEN inputs
         cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
         gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
 
@@ -347,7 +341,7 @@ class BasicTransformerBlock(nn.Module):
         if hidden_states.ndim == 4:
             hidden_states = hidden_states.squeeze(1)
 
-        # 2.5 GLIGEN Control
+        # 1.2 GLIGEN Control
         if gligen_kwargs is not None:
             hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
 
@@ -393,11 +387,9 @@ class BasicTransformerBlock(nn.Module):
 
         if self._chunk_size is not None:
             # "feed_forward_chunk_size" can be used to save memory
-            ff_output = _chunked_feed_forward(
-                self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
-            )
+            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
         else:
-            ff_output = self.ff(norm_hidden_states, scale=lora_scale)
+            ff_output = self.ff(norm_hidden_states)
 
         if self.norm_type == "ada_norm_zero":
             ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -439,7 +431,6 @@ class TemporalBasicTransformerBlock(nn.Module):
 
         # Define 3 blocks. Each block has its own normalization layer.
         # 1. Self-Attn
-        self.norm_in = nn.LayerNorm(dim)
         self.ff_in = FeedForward(
             dim,
             dim_out=time_mix_inner_dim,
@@ -643,7 +634,7 @@ class FeedForward(nn.Module):
         if inner_dim is None:
             inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
-        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
+        linear_cls = nn.Linear
 
         if activation_fn == "gelu":
             act_fn = GELU(dim, inner_dim, bias=bias)
@@ -665,11 +656,10 @@ class FeedForward(nn.Module):
         if final_dropout:
             self.net.append(nn.Dropout(dropout))
 
-    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
-        compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
+    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
         for module in self.net:
-            if isinstance(module, compatible_cls):
-                hidden_states = module(hidden_states, scale)
-            else:
-                hidden_states = module(hidden_states)
+            hidden_states = module(hidden_states)
         return hidden_states
```
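Since `FeedForward` and the transformer blocks no longer consume a per-module `scale`, LoRA strength is applied at the pipeline boundary instead (with the PEFT backend, before the blocks run). A hedged sketch of the supported call pattern; the model ID and LoRA repo are placeholders:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("<lora_repo_id>")  # placeholder repo

# The pipeline consumes `scale` from cross_attention_kwargs (PEFT backend);
# BasicTransformerBlock itself now only warns and ignores the key, per the diff above.
image = pipe(
    "a pixel-art castle",
    cross_attention_kwargs={"scale": 0.7},
).images[0]
```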