diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +19 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +20 -26
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +42 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
- diffusers/schedulers/scheduling_edm_euler.py +50 -31
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
- diffusers/schedulers/scheduling_euler_discrete.py +160 -68
- diffusers/schedulers/scheduling_heun_discrete.py +57 -39
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +24 -26
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
- diffusers-0.28.0.dist-info/RECORD +414 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
- diffusers-0.27.2.dist-info/RECORD +0 -399
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,290 @@
|
|
1
|
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import inspect
|
15
|
+
import re
|
16
|
+
from contextlib import nullcontext
|
17
|
+
from typing import Optional
|
18
|
+
|
19
|
+
from huggingface_hub.utils import validate_hf_hub_args
|
20
|
+
|
21
|
+
from ..utils import deprecate, is_accelerate_available, logging
|
22
|
+
from .single_file_utils import (
|
23
|
+
SingleFileComponentError,
|
24
|
+
convert_controlnet_checkpoint,
|
25
|
+
convert_ldm_unet_checkpoint,
|
26
|
+
convert_ldm_vae_checkpoint,
|
27
|
+
convert_stable_cascade_unet_single_file_to_diffusers,
|
28
|
+
create_controlnet_diffusers_config_from_ldm,
|
29
|
+
create_unet_diffusers_config_from_ldm,
|
30
|
+
create_vae_diffusers_config_from_ldm,
|
31
|
+
fetch_diffusers_config,
|
32
|
+
fetch_original_config,
|
33
|
+
load_single_file_checkpoint,
|
34
|
+
)
|
35
|
+
|
36
|
+
|
37
|
+
logger = logging.get_logger(__name__)
|
38
|
+
|
39
|
+
|
40
|
+
if is_accelerate_available():
|
41
|
+
from accelerate import init_empty_weights
|
42
|
+
|
43
|
+
from ..models.modeling_utils import load_model_dict_into_meta
|
44
|
+
|
45
|
+
|
46
|
+
# Registry of model classes that support `from_single_file`, keyed by class name.
# Each entry may define:
#   - "checkpoint_mapping_fn": converts the original checkpoint into a Diffusers-format state dict (always read).
#   - "config_mapping_fn": converts an original-format config into a Diffusers config
#     (needed only when `original_config` is passed to `from_single_file`).
#   - "default_subfolder": subfolder used when the Diffusers config is looked up automatically.
#   - "legacy_kwargs": old keyword-argument names mapped to their current names.
SINGLE_FILE_LOADABLE_CLASSES = {
    "StableCascadeUNet": {
        "checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
    },
    "UNet2DConditionModel": {
        "checkpoint_mapping_fn": convert_ldm_unet_checkpoint,
        "config_mapping_fn": create_unet_diffusers_config_from_ldm,
        "default_subfolder": "unet",
        # Legacy kwargs supported by `from_single_file` mapped to new args
        "legacy_kwargs": {"num_in_channels": "in_channels"},
    },
    "AutoencoderKL": {
        "checkpoint_mapping_fn": convert_ldm_vae_checkpoint,
        "config_mapping_fn": create_vae_diffusers_config_from_ldm,
        "default_subfolder": "vae",
    },
    "ControlNetModel": {
        "checkpoint_mapping_fn": convert_controlnet_checkpoint,
        "config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
    },
}
|
68
|
+
|
69
|
+
|
70
|
+
def _get_mapping_function_kwargs(mapping_fn, **kwargs):
|
71
|
+
parameters = inspect.signature(mapping_fn).parameters
|
72
|
+
|
73
|
+
mapping_kwargs = {}
|
74
|
+
for parameter in parameters:
|
75
|
+
if parameter in kwargs:
|
76
|
+
mapping_kwargs[parameter] = kwargs[parameter]
|
77
|
+
|
78
|
+
return mapping_kwargs
|
79
|
+
|
80
|
+
|
81
|
+
class FromOriginalModelMixin:
    """
    Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model.
    """

    @classmethod
    @validate_hf_hub_args
    def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs):
        r"""
        Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model
        is set in evaluation mode (`model.eval()`) by default.

        Parameters:
            pretrained_model_link_or_path_or_dict (`str` or `dict`, *optional*):
                Can be either:
                    - A link to the `.safetensors` or `.ckpt` file (for example
                      `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors"`) on the Hub.
                    - A path to a local *file* containing the weights of the component model.
                    - A state dict containing the component model weights.
            config (`str`, *optional*):
                - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted
                  on the Hub.
                - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component
                  configs in Diffusers format.
                Cannot be combined with `original_config`.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            original_config (`str` or `dict`, *optional*):
                Dict or path to a yaml file containing the configuration for the model in its original format.
                If a dict is provided, it will be used to initialize the model configuration. Cannot be combined with
                `config`.
            torch_dtype (`str` or `torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
                dtype is automatically derived from the model's weights.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            resume_download (`bool`, *optional*, defaults to `None`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to True, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Extra keyword arguments. Those whose names match a parameter of the config or checkpoint mapping
                function are forwarded to that function. When the model config is loaded in Diffusers format, those
                matching the model's signature keys override the corresponding config values.

        Raises:
            `ValueError`: If the class is not registered in `SINGLE_FILE_LOADABLE_CLASSES`, if both `config` and
                `original_config` are given, if `original_config` is given for a class without a config mapping
                function, or if `config` is not a string.
            `SingleFileComponentError`: If the converted checkpoint is empty.

        ```py
        >>> from diffusers import StableCascadeUNet

        >>> ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
        >>> model = StableCascadeUNet.from_single_file(ckpt_path)
        ```
        """

        class_name = cls.__name__
        if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
            raise ValueError(
                f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
            )

        # Backward compatibility: accept the old keyword name but warn about it.
        pretrained_model_link_or_path = kwargs.get("pretrained_model_link_or_path", None)
        if pretrained_model_link_or_path is not None:
            deprecation_message = (
                "Please use `pretrained_model_link_or_path_or_dict` argument instead for model classes"
            )
            deprecate("pretrained_model_link_or_path", "1.0.0", deprecation_message)
            pretrained_model_link_or_path_or_dict = pretrained_model_link_or_path

        config = kwargs.pop("config", None)
        original_config = kwargs.pop("original_config", None)

        if config is not None and original_config is not None:
            raise ValueError(
                "`from_single_file` cannot accept both `config` and `original_config` arguments. Please provide only one of these arguments"
            )

        resume_download = kwargs.pop("resume_download", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        token = kwargs.pop("token", None)
        cache_dir = kwargs.pop("cache_dir", None)
        local_files_only = kwargs.pop("local_files_only", None)
        subfolder = kwargs.pop("subfolder", None)
        revision = kwargs.pop("revision", None)
        torch_dtype = kwargs.pop("torch_dtype", None)

        # A dict is taken to be an already-loaded state dict; anything else is resolved as a link or path.
        if isinstance(pretrained_model_link_or_path_or_dict, dict):
            checkpoint = pretrained_model_link_or_path_or_dict
        else:
            checkpoint = load_single_file_checkpoint(
                pretrained_model_link_or_path_or_dict,
                resume_download=resume_download,
                force_download=force_download,
                proxies=proxies,
                token=token,
                cache_dir=cache_dir,
                local_files_only=local_files_only,
                revision=revision,
            )

        mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[class_name]

        checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
        if original_config:
            # Build the Diffusers config by converting the original-format config.
            config_mapping_fn = mapping_functions.get("config_mapping_fn")
            if config_mapping_fn is None:
                raise ValueError(
                    (
                        f"`original_config` has been provided for {class_name} but no mapping function "
                        "was found to convert the original config to a Diffusers config in "
                        "`diffusers.loaders.single_file_utils`"
                    )
                )

            if isinstance(original_config, str):
                # If original_config is a URL or filepath fetch the original_config dict
                original_config = fetch_original_config(original_config, local_files_only=local_files_only)

            config_mapping_kwargs = _get_mapping_function_kwargs(config_mapping_fn, **kwargs)
            diffusers_model_config = config_mapping_fn(
                original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
            )
        else:
            # Load a config that is already in Diffusers format, either from the
            # user-supplied repo id / path or from one derived from the checkpoint.
            if config:
                if isinstance(config, str):
                    default_pretrained_model_config_name = config
                else:
                    raise ValueError(
                        (
                            "Invalid `config` argument. Please provide a string representing a repo id "
                            "or path to a local Diffusers model repo."
                        )
                    )

            else:
                config = fetch_diffusers_config(checkpoint)
                default_pretrained_model_config_name = config["pretrained_model_name_or_path"]

                if "default_subfolder" in mapping_functions:
                    subfolder = mapping_functions["default_subfolder"]

                # some configs contain a subfolder key, e.g. StableCascadeUNet
                subfolder = subfolder or config.pop("subfolder", None)

            diffusers_model_config = cls.load_config(
                pretrained_model_name_or_path=default_pretrained_model_config_name,
                subfolder=subfolder,
                local_files_only=local_files_only,
            )
            expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)

            # Map legacy kwargs to new kwargs
            if "legacy_kwargs" in mapping_functions:
                legacy_kwargs = mapping_functions["legacy_kwargs"]
                for legacy_key, new_key in legacy_kwargs.items():
                    if legacy_key in kwargs:
                        kwargs[new_key] = kwargs.pop(legacy_key)

            # Remaining kwargs that match the model's signature keys override the loaded config values.
            model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
            diffusers_model_config.update(model_kwargs)

        checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
        diffusers_format_checkpoint = checkpoint_mapping_fn(
            config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
        )
        if not diffusers_format_checkpoint:
            raise SingleFileComponentError(
                f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
            )

        # `init_empty_weights` is only imported when accelerate is available; otherwise build the model normally.
        ctx = init_empty_weights if is_accelerate_available() else nullcontext
        with ctx():
            model = cls.from_config(diffusers_model_config)

        if is_accelerate_available():
            unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
            # Drop unexpected keys the model explicitly declares as ignorable.
            if model._keys_to_ignore_on_load_unexpected is not None:
                for pat in model._keys_to_ignore_on_load_unexpected:
                    unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

            if len(unexpected_keys) > 0:
                logger.warning(
                    f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {', '.join(unexpected_keys)}"
                )
        else:
            model.load_state_dict(diffusers_format_checkpoint)

        if torch_dtype is not None:
            model.to(torch_dtype)

        model.eval()

        return model
|