diffusers-0.27.2-py3-none-any.whl → diffusers-0.28.0-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +19 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +20 -26
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +42 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
- diffusers/schedulers/scheduling_edm_euler.py +50 -31
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
- diffusers/schedulers/scheduling_euler_discrete.py +160 -68
- diffusers/schedulers/scheduling_heun_discrete.py +57 -39
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +24 -26
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
- diffusers-0.28.0.dist-info/RECORD +414 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
- diffusers-0.27.2.dist-info/RECORD +0 -399
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/models/modeling_utils.py
CHANGED
@@ -20,6 +20,7 @@ import os
 import re
 from collections import OrderedDict
 from functools import partial
+from pathlib import Path
 from typing import Any, Callable, List, Optional, Tuple, Union
 
 import safetensors
@@ -32,7 +33,6 @@ from .. import __version__
 from ..utils import (
     CONFIG_NAME,
     FLAX_WEIGHTS_NAME,
-    SAFETENSORS_FILE_EXTENSION,
     SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_NAME,
     _add_variant,
@@ -43,6 +43,12 @@ from ..utils import (
     logging,
 )
 from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
+from .model_loading_utils import (
+    _determine_device_map,
+    _load_state_dict_into_model,
+    load_model_dict_into_meta,
+    load_state_dict,
+)
 
 
 logger = logging.get_logger(__name__)
@@ -56,8 +62,6 @@ else:
 
 if is_accelerate_available():
     import accelerate
-    from accelerate.utils import set_module_tensor_to_device
-    from accelerate.utils.versions import is_torch_version
 
 
 def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
@@ -98,89 +102,6 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
     return first_tuple[1].dtype
 
 
-def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
-    """
-    Reads a checkpoint file, returning properly formatted errors if they arise.
-    """
-    try:
-        file_extension = os.path.basename(checkpoint_file).split(".")[-1]
-        if file_extension == SAFETENSORS_FILE_EXTENSION:
-            return safetensors.torch.load_file(checkpoint_file, device="cpu")
-        else:
-            return torch.load(checkpoint_file, map_location="cpu")
-    except Exception as e:
-        try:
-            with open(checkpoint_file) as f:
-                if f.read().startswith("version"):
-                    raise OSError(
-                        "You seem to have cloned a repository without having git-lfs installed. Please install "
-                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
-                        "you cloned."
-                    )
-                else:
-                    raise ValueError(
-                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
-                        "model. Make sure you have saved the model properly."
-                    ) from e
-        except (UnicodeDecodeError, ValueError):
-            raise OSError(
-                f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. "
-            )
-
-
-def load_model_dict_into_meta(
-    model,
-    state_dict: OrderedDict,
-    device: Optional[Union[str, torch.device]] = None,
-    dtype: Optional[Union[str, torch.dtype]] = None,
-    model_name_or_path: Optional[str] = None,
-) -> List[str]:
-    device = device or torch.device("cpu")
-    dtype = dtype or torch.float32
-
-    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
-
-    unexpected_keys = []
-    empty_state_dict = model.state_dict()
-    for param_name, param in state_dict.items():
-        if param_name not in empty_state_dict:
-            unexpected_keys.append(param_name)
-            continue
-
-        if empty_state_dict[param_name].shape != param.shape:
-            model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
-            raise ValueError(
-                f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
-            )
-
-        if accepts_dtype:
-            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
-        else:
-            set_module_tensor_to_device(model, param_name, device, value=param)
-    return unexpected_keys
-
-
-def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
-    # Convert old format to new format if needed from a PyTorch state_dict
-    # copy state_dict so _load_from_state_dict can modify it
-    state_dict = state_dict.copy()
-    error_msgs = []
-
-    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
-    # so we need to apply the function recursively.
-    def load(module: torch.nn.Module, prefix: str = ""):
-        args = (state_dict, prefix, {}, True, [], [], error_msgs)
-        module._load_from_state_dict(*args)
-
-        for name, child in module._modules.items():
-            if child is not None:
-                load(child, prefix + name + ".")
-
-    load(model_to_load)
-
-    return error_msgs
-
-
 class ModelMixin(torch.nn.Module, PushToHubMixin):
     r"""
     Base class for all models.
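The three helpers removed above are not gone: they moved into the new diffusers/models/model_loading_utils.py (the +149 entry in the file list). A minimal sketch of updating code that imported them from the old private location; these are internal helpers, and the checkpoint path below is a hypothetical example:

# Before (0.27.2):
# from diffusers.models.modeling_utils import load_state_dict

# After (0.28.0):
from diffusers.models.model_loading_utils import load_state_dict

# Reads a .safetensors or torch checkpoint into a plain state dict on CPU.
state_dict = load_state_dict("unet/diffusion_pytorch_model.safetensors")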
@@ -195,6 +116,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
     _supports_gradient_checkpointing = False
     _keys_to_ignore_on_load_unexpected = None
+    _no_split_modules = None
 
     def __init__(self):
         super().__init__()
@@ -241,6 +163,36 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         if self._supports_gradient_checkpointing:
             self.apply(partial(self._set_gradient_checkpointing, value=False))
 
+    def set_use_npu_flash_attention(self, valid: bool) -> None:
+        r"""
+        Set the switch for the npu flash attention.
+        """
+
+        def fn_recursive_set_npu_flash_attention(module: torch.nn.Module):
+            if hasattr(module, "set_use_npu_flash_attention"):
+                module.set_use_npu_flash_attention(valid)
+
+            for child in module.children():
+                fn_recursive_set_npu_flash_attention(child)
+
+        for module in self.children():
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_set_npu_flash_attention(module)
+
+    def enable_npu_flash_attention(self) -> None:
+        r"""
+        Enable npu flash attention from torch_npu
+
+        """
+        self.set_use_npu_flash_attention(True)
+
+    def disable_npu_flash_attention(self) -> None:
+        r"""
+        disable npu flash attention from torch_npu
+
+        """
+        self.set_use_npu_flash_attention(False)
+
     def set_use_memory_efficient_attention_xformers(
         self, valid: bool, attention_op: Optional[Callable] = None
     ) -> None:
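The new toggles walk the module tree recursively, so calling them once on the top-level model flips every submodule that implements `set_use_npu_flash_attention`. A usage sketch, assuming an Ascend NPU environment with `torch_npu` installed; the model ID is illustrative:

import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)
unet.enable_npu_flash_attention()   # recursively switches attention to the NPU flash kernel
# ... run inference ...
unet.disable_npu_flash_attention()  # restores the default attention processors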
@@ -367,18 +319,18 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         # Save the model
         if safe_serialization:
             safetensors.torch.save_file(
-                state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
+                state_dict, Path(save_directory, weights_name).as_posix(), metadata={"format": "pt"}
             )
         else:
-            torch.save(state_dict, os.path.join(save_directory, weights_name))
+            torch.save(state_dict, Path(save_directory, weights_name).as_posix())
 
-        logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
+        logger.info(f"Model weights saved in {Path(save_directory, weights_name).as_posix()}")
 
         if push_to_hub:
             # Create a new empty model card and eventually tag it
             model_card = load_or_create_model_card(repo_id, token=token)
             model_card = populate_model_card(model_card)
-            model_card.save(os.path.join(save_directory, "README.md"))
+            model_card.save(Path(save_directory, "README.md").as_posix())
 
             self._upload_folder(
                 save_directory,
|
|
415
367
|
force_download (`bool`, *optional*, defaults to `False`):
|
416
368
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
417
369
|
cached versions if they exist.
|
418
|
-
resume_download
|
419
|
-
|
420
|
-
|
370
|
+
resume_download:
|
371
|
+
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
|
372
|
+
of Diffusers.
|
421
373
|
proxies (`Dict[str, str]`, *optional*):
|
422
374
|
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
423
375
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
@@ -499,7 +451,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
         force_download = kwargs.pop("force_download", False)
         from_flax = kwargs.pop("from_flax", False)
-        resume_download = kwargs.pop("resume_download", False)
+        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         output_loading_info = kwargs.pop("output_loading_info", False)
         local_files_only = kwargs.pop("local_files_only", None)
@@ -554,6 +506,36 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
             )
 
+        # change device_map into a map if we passed an int, a str or a torch.device
+        if isinstance(device_map, torch.device):
+            device_map = {"": device_map}
+        elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
+            try:
+                device_map = {"": torch.device(device_map)}
+            except RuntimeError:
+                raise ValueError(
+                    "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
+                    f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
+                )
+        elif isinstance(device_map, int):
+            if device_map < 0:
+                raise ValueError(
+                    "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
+                )
+            else:
+                device_map = {"": device_map}
+
+        if device_map is not None:
+            if low_cpu_mem_usage is None:
+                low_cpu_mem_usage = True
+            elif not low_cpu_mem_usage:
+                raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
+
+        if low_cpu_mem_usage:
+            if device_map is not None and not is_torch_version(">=", "1.10"):
+                # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
+                raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")
+
         # Load config if we don't provide a configuration
         config_path = pretrained_model_name_or_path
 
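With this normalization, `from_pretrained` accepts a concrete device (device-name string, `torch.device`, or non-negative GPU index) in addition to the accelerate placement strategies, and any `device_map` implies `low_cpu_mem_usage=True`. A sketch of the accepted forms, assuming accelerate is installed; the model ID is illustrative:

import torch
from diffusers import UNet2DConditionModel

# Loads the whole model onto one device ({"": device} internally):
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", device_map="cuda:0"
)
# Equivalent: device_map=torch.device("cuda:0") or device_map=0.
# Strings other than a device name must be one of the accelerate strategies:
# "auto", "balanced", "balanced_low_0", "sequential".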
@@ -576,10 +558,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             token=token,
             revision=revision,
             subfolder=subfolder,
-            device_map=device_map,
-            max_memory=max_memory,
-            offload_folder=offload_folder,
-            offload_state_dict=offload_state_dict,
             user_agent=user_agent,
             **kwargs,
         )
@@ -684,6 +662,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         else:  # else let accelerate handle loading and dispatching.
             # Load weights and dispatch according to the device_map
             # by default the device_map is None and the weights are loaded on the CPU
+            device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
             try:
                 accelerate.load_checkpoint_and_dispatch(
                     model,
@@ -693,6 +672,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     offload_folder=offload_folder,
                     offload_state_dict=offload_state_dict,
                     dtype=torch_dtype,
+                    force_hooks=True,
+                    strict=True,
                 )
             except AttributeError as e:
                 # When using accelerate loading, we do not have the ability to load the state
@@ -873,6 +854,45 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
 
         return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
 
+    @classmethod
+    def _get_signature_keys(cls, obj):
+        parameters = inspect.signature(obj.__init__).parameters
+        required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
+        optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
+        expected_modules = set(required_parameters.keys()) - {"self"}
+
+        return expected_modules, optional_parameters
+
+    # Adapted from `transformers` modeling_utils.py
+    def _get_no_split_modules(self, device_map: str):
+        """
+        Get the modules of the model that should not be split when using device_map. We iterate through the modules to
+        get the underlying `_no_split_modules`.
+
+        Args:
+            device_map (`str`):
+                The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
+
+        Returns:
+            `List[str]`: List of modules that should not be split
+        """
+        _no_split_modules = set()
+        modules_to_check = [self]
+        while len(modules_to_check) > 0:
+            module = modules_to_check.pop(-1)
+            # if the module does not appear in _no_split_modules, we also check the children
+            if module.__class__.__name__ not in _no_split_modules:
+                if isinstance(module, ModelMixin):
+                    if module._no_split_modules is None:
+                        raise ValueError(
+                            f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
+                            "class needs to implement the `_no_split_modules` attribute."
+                        )
+                    else:
+                        _no_split_modules = _no_split_modules | set(module._no_split_modules)
+                modules_to_check += list(module.children())
+        return list(_no_split_modules)
+
     @property
     def device(self) -> torch.device:
         """
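Strategy-style device maps now require the model class to declare `_no_split_modules`; `_get_no_split_modules` raises otherwise. A minimal sketch of a custom model opting in — the class and listed block name are hypothetical:

import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class TinyModel(ModelMixin, ConfigMixin):  # hypothetical example class
    # Module class names listed here are never split across devices by accelerate.
    _no_split_modules = ["BasicTransformerBlock"]

    @register_to_config
    def __init__(self, dim: int = 32):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)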
diffusers/models/resnet.py
CHANGED
@@ -58,7 +58,7 @@ class ResnetBlockCondNorm2D(nn.Module):
         non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
         time_embedding_norm (`str`, *optional*, default to `"ada_group"` ):
             The normalization layer for time embedding `temb`. Currently only support "ada_group" or "spatial".
-        kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
+        kernel (`torch.Tensor`, optional, default to None): FIR filter, see
             [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
         output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
         use_in_shortcut (`bool`, *optional*, default to `True`):
@@ -101,8 +101,6 @@ class ResnetBlockCondNorm2D(nn.Module):
         self.output_scale_factor = output_scale_factor
         self.time_embedding_norm = time_embedding_norm
 
-        conv_cls = nn.Conv2d
-
         if groups_out is None:
             groups_out = groups
 
@@ -113,7 +111,7 @@ class ResnetBlockCondNorm2D(nn.Module):
         else:
             raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}")
 
-        self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
 
         if self.time_embedding_norm == "ada_group":  # ada_group
             self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
@@ -125,7 +123,7 @@ class ResnetBlockCondNorm2D(nn.Module):
         self.dropout = torch.nn.Dropout(dropout)
 
         conv_2d_out_channels = conv_2d_out_channels or out_channels
-        self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
 
         self.nonlinearity = get_activation(non_linearity)
 
@@ -139,7 +137,7 @@ class ResnetBlockCondNorm2D(nn.Module):
 
         self.conv_shortcut = None
         if self.use_in_shortcut:
-            self.conv_shortcut = conv_cls(
+            self.conv_shortcut = nn.Conv2d(
                 in_channels,
                 conv_2d_out_channels,
                 kernel_size=1,
@@ -148,7 +146,7 @@ class ResnetBlockCondNorm2D(nn.Module):
                 bias=conv_shortcut_bias,
             )
 
-    def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
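As the deprecation message above says, `scale` is no longer threaded through the block forwards; a LoRA scale is passed at the pipeline call via `cross_attention_kwargs` instead. A sketch, with a hypothetical LoRA repository:

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_lora_weights("user/my-lora")  # hypothetical LoRA repo

# The LoRA scale now travels via cross_attention_kwargs, not per-module `scale` args:
image = pipe("an astronaut riding a horse", cross_attention_kwargs={"scale": 0.7}).images[0]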
@@ -204,9 +202,9 @@ class ResnetBlock2D(nn.Module):
         eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
         non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
         time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
-            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
-            "ada_group" for a stronger conditioning with scale and shift.
-        kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
+            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" for a
+            stronger conditioning with scale and shift.
+        kernel (`torch.Tensor`, optional, default to None): FIR filter, see
             [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
         output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
         use_in_shortcut (`bool`, *optional*, default to `True`):
@@ -234,7 +232,7 @@ class ResnetBlock2D(nn.Module):
         non_linearity: str = "swish",
         skip_time_act: bool = False,
         time_embedding_norm: str = "default",  # default, scale_shift,
-        kernel: Optional[torch.FloatTensor] = None,
+        kernel: Optional[torch.Tensor] = None,
         output_scale_factor: float = 1.0,
         use_in_shortcut: Optional[bool] = None,
         up: bool = False,
@@ -263,21 +261,18 @@ class ResnetBlock2D(nn.Module):
         self.time_embedding_norm = time_embedding_norm
         self.skip_time_act = skip_time_act
 
-        linear_cls = nn.Linear
-        conv_cls = nn.Conv2d
-
         if groups_out is None:
             groups_out = groups
 
         self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
 
-        self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
 
         if temb_channels is not None:
             if self.time_embedding_norm == "default":
-                self.time_emb_proj = linear_cls(temb_channels, out_channels)
+                self.time_emb_proj = nn.Linear(temb_channels, out_channels)
             elif self.time_embedding_norm == "scale_shift":
-                self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
+                self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels)
             else:
                 raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
         else:
@@ -287,7 +282,7 @@ class ResnetBlock2D(nn.Module):
 
         self.dropout = torch.nn.Dropout(dropout)
         conv_2d_out_channels = conv_2d_out_channels or out_channels
-        self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
 
         self.nonlinearity = get_activation(non_linearity)
 
@@ -313,7 +308,7 @@ class ResnetBlock2D(nn.Module):
 
         self.conv_shortcut = None
         if self.use_in_shortcut:
-            self.conv_shortcut = conv_cls(
+            self.conv_shortcut = nn.Conv2d(
                 in_channels,
                 conv_2d_out_channels,
                 kernel_size=1,
@@ -322,7 +317,7 @@ class ResnetBlock2D(nn.Module):
                 bias=conv_shortcut_bias,
             )
 
-    def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -610,7 +605,7 @@ class TemporalResnetBlock(nn.Module):
             padding=0,
         )
 
-    def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
         hidden_states = input_tensor
 
         hidden_states = self.norm1(hidden_states)
@@ -690,8 +685,8 @@ class SpatioTemporalResBlock(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
         image_only_indicator: Optional[torch.Tensor] = None,
     ):
         num_frames = image_only_indicator.shape[-1]
diffusers/models/transformer_temporal.py
CHANGED
@@ -20,15 +20,15 @@ from .transformers.transformer_temporal import (
 
 
 class TransformerTemporalModelOutput(TransformerTemporalModelOutput):
-    deprecation_message = "Importing `TransformerTemporalModelOutput` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.tranformer_temporal import TransformerTemporalModelOutput`, instead."
+    deprecation_message = "Importing `TransformerTemporalModelOutput` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.transformer_temporal import TransformerTemporalModelOutput`, instead."
     deprecate("TransformerTemporalModelOutput", "0.29", deprecation_message)
 
 
 class TransformerTemporalModel(TransformerTemporalModel):
-    deprecation_message = "Importing `TransformerTemporalModel` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.tranformer_temporal import TransformerTemporalModel`, instead."
+    deprecation_message = "Importing `TransformerTemporalModel` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.transformer_temporal import TransformerTemporalModel`, instead."
     deprecate("TransformerTemporalModel", "0.29", deprecation_message)
 
 
 class TransformerSpatioTemporalModel(TransformerSpatioTemporalModel):
-    deprecation_message = "Importing `TransformerSpatioTemporalModel` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.tranformer_temporal import TransformerSpatioTemporalModel`, instead."
+    deprecation_message = "Importing `TransformerSpatioTemporalModel` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.transformer_temporal import TransformerSpatioTemporalModel`, instead."
     deprecate("TransformerTemporalModelOutput", "0.29", deprecation_message)
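The change in these shims appears to be a typo fix in the suggested replacement path (`tranformer_temporal` → `transformer_temporal`). The shims keep working but warn until 0.29; the non-deprecated import is:

# Deprecated (warns, scheduled for removal in 0.29):
# from diffusers.models.transformer_temporal import TransformerTemporalModel

# Current location:
from diffusers.models.transformers.transformer_temporal import TransformerTemporalModel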
diffusers/models/transformers/dual_transformer_2d.py
CHANGED
@@ -106,21 +106,21 @@ class DualTransformer2DModel(nn.Module):
         """
         Args:
             hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
-                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
-                hidden_states.
+                When continuous, `torch.Tensor` of shape `(batch size, channel, height, width)`): Input hidden_states.
             encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                 self-attention.
             timestep ( `torch.long`, *optional*):
                 Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
-            attention_mask (`torch.FloatTensor`, *optional*):
+            attention_mask (`torch.Tensor`, *optional*):
                 Optional attention mask to be applied in Attention.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+                Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+                tuple.
 
         Returns:
             [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
diffusers/models/transformers/prior_transformer.py
CHANGED
@@ -26,11 +26,11 @@ class PriorTransformerOutput(BaseOutput):
     The output of [`PriorTransformer`].
 
     Args:
-        predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+        predicted_image_embedding (`torch.Tensor` of shape `(batch_size, embedding_dim)`):
             The predicted CLIP image embedding conditioned on the CLIP text embedding input.
     """
 
-    predicted_image_embedding: torch.FloatTensor
+    predicted_image_embedding: torch.Tensor
 
 
 class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
@@ -246,8 +246,8 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
         self,
         hidden_states,
         timestep: Union[torch.Tensor, float, int],
-        proj_embedding: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        proj_embedding: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.BoolTensor] = None,
         return_dict: bool = True,
     ):
@@ -255,13 +255,13 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
         The [`PriorTransformer`] forward method.
 
         Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+            hidden_states (`torch.Tensor` of shape `(batch_size, embedding_dim)`):
                 The currently predicted image embeddings.
             timestep (`torch.LongTensor`):
                 Current denoising step.
-            proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+            proj_embedding (`torch.Tensor` of shape `(batch_size, embedding_dim)`):
                 Projected embedding vector the denoising process is conditioned on.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
+            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
                 Hidden states of the text embeddings the denoising process is conditioned on.
             attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
                 Text mask for the text embeddings.
diffusers/models/transformers/t5_film_transformer.py
CHANGED
@@ -86,7 +86,7 @@ class T5FilmDecoder(ModelMixin, ConfigMixin):
         self.post_dropout = nn.Dropout(p=dropout_rate)
         self.spec_out = nn.Linear(d_model, input_dims, bias=False)
 
-    def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor:
+    def encoder_decoder_mask(self, query_input: torch.Tensor, key_input: torch.Tensor) -> torch.Tensor:
         mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
         return mask.unsqueeze(-3)
 
@@ -195,13 +195,13 @@ class DecoderLayer(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        conditioning_emb: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        conditioning_emb: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         encoder_decoder_position_bias=None,
-    ) -> Tuple[torch.FloatTensor]:
+    ) -> Tuple[torch.Tensor]:
         hidden_states = self.layer[0](
             hidden_states,
             conditioning_emb=conditioning_emb,
@@ -249,10 +249,10 @@ class T5LayerSelfAttentionCond(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        conditioning_emb: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        hidden_states: torch.Tensor,
+        conditioning_emb: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         # pre_self_attention_layer_norm
         normed_hidden_states = self.layer_norm(hidden_states)
 
@@ -292,10 +292,10 @@ class T5LayerCrossAttention(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        key_value_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = self.attention(
             normed_hidden_states,
@@ -328,9 +328,7 @@ class T5LayerFFCond(nn.Module):
         self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
         self.dropout = nn.Dropout(dropout_rate)
 
-    def forward(
-        self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None
-    ) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, conditioning_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
         forwarded_states = self.layer_norm(hidden_states)
         if conditioning_emb is not None:
             forwarded_states = self.film(forwarded_states, conditioning_emb)
@@ -361,7 +359,7 @@ class T5DenseGatedActDense(nn.Module):
         self.dropout = nn.Dropout(dropout_rate)
         self.act = NewGELUActivation()
 
-    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_gelu = self.act(self.wi_0(hidden_states))
         hidden_linear = self.wi_1(hidden_states)
         hidden_states = hidden_gelu * hidden_linear
@@ -390,7 +388,7 @@ class T5LayerNorm(nn.Module):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
-    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
         # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
         # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
@@ -431,7 +429,7 @@ class T5FiLMLayer(nn.Module):
         super().__init__()
         self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
 
-    def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, x: torch.Tensor, conditioning_emb: torch.Tensor) -> torch.Tensor:
         emb = self.scale_bias(conditioning_emb)
         scale, shift = torch.chunk(emb, 2, -1)
         x = x * (1 + scale) + shift