diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +19 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +20 -26
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +42 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
- diffusers/schedulers/scheduling_edm_euler.py +50 -31
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
- diffusers/schedulers/scheduling_euler_discrete.py +160 -68
- diffusers/schedulers/scheduling_heun_discrete.py +57 -39
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +24 -26
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
- diffusers-0.28.0.dist-info/RECORD +414 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
- diffusers-0.27.2.dist-info/RECORD +0 -399
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
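
The largest single change in this release is the rewrite of the single-file loading stack (diffusers/loaders/single_file.py, single_file_model.py and single_file_utils.py); the detailed diff below covers diffusers/loaders/single_file_utils.py, which now infers the matching Hub pipeline configuration directly from checkpoint keys instead of the legacy YAML config URLs. A minimal usage sketch of the public entry point this stack serves, assuming a local `.safetensors` checkpoint path (the path is a placeholder, and exact keyword arguments may vary between versions):

```python
# Sketch only: load a Stable Diffusion pipeline from one checkpoint file.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_single_file(
    "path/to/v1-5-pruned-emaonly.safetensors",  # hypothetical local file
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
image = pipe("a photo of an astronaut riding a horse").images[0]
```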
diffusers/loaders/single_file_utils.py

@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""
+"""Conversion script for the Stable Diffusion checkpoints."""

 import os
 import re
@@ -26,7 +26,6 @@ import yaml
 from ..models.modeling_utils import load_state_dict
 from ..schedulers import (
     DDIMScheduler,
-    DDPMScheduler,
     DPMSolverMultistepScheduler,
     EDMDPMSolverMultistepScheduler,
     EulerAncestralDiscreteScheduler,
@@ -35,133 +34,85 @@ from ..schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ..utils import
+from ..utils import (
+    SAFETENSORS_WEIGHTS_NAME,
+    WEIGHTS_NAME,
+    deprecate,
+    is_accelerate_available,
+    is_transformers_available,
+    logging,
+)
 from ..utils.hub_utils import _get_model_file


 if is_transformers_available():
-    from transformers import (
-        CLIPTextConfig,
-        CLIPTextModel,
-        CLIPTextModelWithProjection,
-        CLIPTokenizer,
-    )
+    from transformers import AutoImageProcessor

 if is_accelerate_available():
     from accelerate import init_empty_weights

-
+    from ..models.modeling_utils import load_model_dict_into_meta

-
-    "v1": "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml",
-    "v2": "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml",
-    "xl": "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml",
-    "xl_refiner": "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml",
-    "upscale": "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml",
-    "controlnet": "https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml",
-}
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 CHECKPOINT_KEY_NAMES = {
     "v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
     "xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
     "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
+    "upscale": "model.diffusion_model.input_blocks.10.0.skip_connection.bias",
+    "controlnet": "control_model.time_embed.0.weight",
+    "playground-v2-5": "edm_mean",
+    "inpainting": "model.diffusion_model.input_blocks.0.0.weight",
+    "clip": "cond_stage_model.transformer.text_model.embeddings.position_ids",
+    "clip_sdxl": "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight",
+    "open_clip": "cond_stage_model.model.token_embedding.weight",
+    "open_clip_sdxl": "conditioner.embedders.1.model.positional_embedding",
+    "open_clip_sdxl_refiner": "conditioner.embedders.0.model.text_projection",
+    "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
+    "stable_cascade_stage_c": "clip_txt_mapper.weight",
 }

-
-    "
-    "
-    "
-    "
-    "
-    "
-    "
-    "
-    "
-    "
-    "
+DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
+    "xl_base": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0"},
+    "xl_refiner": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-refiner-1.0"},
+    "xl_inpaint": {"pretrained_model_name_or_path": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"},
+    "playground-v2-5": {"pretrained_model_name_or_path": "playgroundai/playground-v2.5-1024px-aesthetic"},
+    "upscale": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-x4-upscaler"},
+    "inpainting": {"pretrained_model_name_or_path": "runwayml/stable-diffusion-inpainting"},
+    "inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"},
+    "controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"},
+    "v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"},
+    "v1": {"pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5"},
+    "stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"},
+    "stable_cascade_stage_b_lite": {
+        "pretrained_model_name_or_path": "stabilityai/stable-cascade",
+        "subfolder": "decoder_lite",
+    },
+    "stable_cascade_stage_c": {
+        "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
+        "subfolder": "prior",
+    },
+    "stable_cascade_stage_c_lite": {
+        "pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
+        "subfolder": "prior_lite",
+    },
 }

-
-
-    "
-    "
-    "
-    "
+# Use to configure model sample size when original config is provided
+DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP = {
+    "xl_base": 1024,
+    "xl_refiner": 1024,
+    "xl_inpaint": 1024,
+    "playground-v2-5": 1024,
+    "upscale": 512,
+    "inpainting": 512,
+    "inpainting_v2": 512,
+    "controlnet": 512,
+    "v2": 768,
+    "v1": 512,
 }


-def convert_stable_cascade_unet_single_file_to_diffusers(original_state_dict):
-    is_stage_c = "clip_txt_mapper.weight" in original_state_dict
-
-    if is_stage_c:
-        state_dict = {}
-        for key in original_state_dict.keys():
-            if key.endswith("in_proj_weight"):
-                weights = original_state_dict[key].chunk(3, 0)
-                state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
-                state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
-                state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
-            elif key.endswith("in_proj_bias"):
-                weights = original_state_dict[key].chunk(3, 0)
-                state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
-                state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
-                state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
-            elif key.endswith("out_proj.weight"):
-                weights = original_state_dict[key]
-                state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
-            elif key.endswith("out_proj.bias"):
-                weights = original_state_dict[key]
-                state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
-            else:
-                state_dict[key] = original_state_dict[key]
-    else:
-        state_dict = {}
-        for key in original_state_dict.keys():
-            if key.endswith("in_proj_weight"):
-                weights = original_state_dict[key].chunk(3, 0)
-                state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
-                state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
-                state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
-            elif key.endswith("in_proj_bias"):
-                weights = original_state_dict[key].chunk(3, 0)
-                state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
-                state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
-                state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
-            elif key.endswith("out_proj.weight"):
-                weights = original_state_dict[key]
-                state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
-            elif key.endswith("out_proj.bias"):
-                weights = original_state_dict[key]
-                state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
-            # rename clip_mapper to clip_txt_pooled_mapper
-            elif key.endswith("clip_mapper.weight"):
-                weights = original_state_dict[key]
-                state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
-            elif key.endswith("clip_mapper.bias"):
-                weights = original_state_dict[key]
-                state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
-            else:
-                state_dict[key] = original_state_dict[key]
-
-    return state_dict
-
-
-def infer_stable_cascade_single_file_config(checkpoint):
-    is_stage_c = "clip_txt_mapper.weight" in checkpoint
-    is_stage_b = "down_blocks.1.0.channelwise.0.weight" in checkpoint
-
-    if is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 1536):
-        config_type = "stage_c_lite"
-    elif is_stage_c and (checkpoint["clip_txt_mapper.weight"].shape[0] == 2048):
-        config_type = "stage_c"
-    elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 576:
-        config_type = "stage_b_lite"
-    elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 640:
-        config_type = "stage_b"
-
-    return STABLE_CASCADE_DEFAULT_CONFIGS[config_type]
-
-
 DIFFUSERS_TO_LDM_MAPPING = {
     "unet": {
         "layers": {
@@ -255,14 +206,6 @@ DIFFUSERS_TO_LDM_MAPPING = {
     },
 }

-LDM_VAE_KEY = "first_stage_model."
-LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
-PLAYGROUND_VAE_SCALING_FACTOR = 0.5
-LDM_UNET_KEY = "model.diffusion_model."
-LDM_CONTROLNET_KEY = "control_model."
-LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
-LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
-
 SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [
     "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_bias",
     "cond_stage_model.model.transformer.resblocks.23.attn.in_proj_weight",
@@ -279,11 +222,51 @@ SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [
     "cond_stage_model.model.text_projection",
 ]

+# To support legacy scheduler_type argument
+SCHEDULER_DEFAULT_CONFIG = {
+    "beta_schedule": "scaled_linear",
+    "beta_start": 0.00085,
+    "beta_end": 0.012,
+    "interpolation_type": "linear",
+    "num_train_timesteps": 1000,
+    "prediction_type": "epsilon",
+    "sample_max_value": 1.0,
+    "set_alpha_to_one": False,
+    "skip_prk_steps": True,
+    "steps_offset": 1,
+    "timestep_spacing": "leading",
+}
+
+LDM_VAE_KEY = "first_stage_model."
+LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
+PLAYGROUND_VAE_SCALING_FACTOR = 0.5
+LDM_UNET_KEY = "model.diffusion_model."
+LDM_CONTROLNET_KEY = "control_model."
+LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
+OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
+LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024

 VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]


+class SingleFileComponentError(Exception):
+    def __init__(self, message=None):
+        self.message = message
+        super().__init__(self.message)
+
+
+def is_valid_url(url):
+    result = urlparse(url)
+    if result.scheme and result.netloc:
+        return True
+
+    return False
+
+
 def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
+    if not is_valid_url(pretrained_model_name_or_path):
+        raise ValueError("Invalid `pretrained_model_name_or_path` provided. Please set it to a valid URL.")
+
     pattern = r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)"
     weights_name = None
     repo_id = (None,)
@@ -291,6 +274,7 @@ def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
         pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "")
     match = re.match(pattern, pretrained_model_name_or_path)
     if not match:
+        logger.warning("Unable to identify the repo_id and weights_name from the provided URL.")
         return repo_id, weights_name

     repo_id = f"{match.group(1)}/{match.group(2)}"
@@ -299,34 +283,18 @@ def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
     return repo_id, weights_name


-def
-
-
-    original_config_file=None,
-    resume_download=False,
-    force_download=False,
-    proxies=None,
-    token=None,
-    cache_dir=None,
-    local_files_only=None,
-    revision=None,
-):
-    checkpoint = load_single_file_model_checkpoint(
-        pretrained_model_link_or_path,
-        resume_download=resume_download,
-        force_download=force_download,
-        proxies=proxies,
-        token=token,
-        cache_dir=cache_dir,
-        local_files_only=local_files_only,
-        revision=revision,
-    )
-    original_config = fetch_original_config(class_name, checkpoint, original_config_file)
+def _is_model_weights_in_cached_folder(cached_folder, name):
+    pretrained_model_name_or_path = os.path.join(cached_folder, name)
+    weights_exist = False

-
+    for weights_name in [WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME]:
+        if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
+            weights_exist = True

+    return weights_exist

-
+
+def load_single_file_checkpoint(
     pretrained_model_link_or_path,
     resume_download=False,
     force_download=False,
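
The URL-parsing helper kept above turns a Hub file URL into a (repo_id, weights_name) pair. A standalone re-implementation sketch of what the `([^/]+)/([^/]+)/(?:blob/main/)?(.+)` pattern extracts once the Hub prefix is stripped (for illustration only, not the library function itself):

```python
import re

# Same pattern and prefixes as _extract_repo_id_and_weights_name above.
pattern = r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)"
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"

for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
    url = url.replace(prefix, "")

match = re.match(pattern, url)
repo_id = f"{match.group(1)}/{match.group(2)}"  # "runwayml/stable-diffusion-v1-5"
weights_name = match.group(3)                   # "v1-5-pruned-emaonly.safetensors"
```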
@@ -337,10 +305,11 @@ def load_single_file_model_checkpoint(
     revision=None,
 ):
     if os.path.isfile(pretrained_model_link_or_path):
-
+        pretrained_model_link_or_path = pretrained_model_link_or_path
+
     else:
         repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
-
+        pretrained_model_link_or_path = _get_model_file(
             repo_id,
             weights_name=weights_name,
             force_download=force_download,
@@ -351,7 +320,8 @@ def load_single_file_model_checkpoint(
             token=token,
             revision=revision,
         )
-
+
+    checkpoint = load_state_dict(pretrained_model_link_or_path)

     # some checkpoints contain the model state dict under a "state_dict" key
     while "state_dict" in checkpoint:
@@ -360,120 +330,154 @@ def load_single_file_model_checkpoint(
     return checkpoint


-def
-    if
-
+def fetch_original_config(original_config_file, local_files_only=False):
+    if os.path.isfile(original_config_file):
+        with open(original_config_file, "r") as fp:
+            original_config_file = fp.read()

-    elif
-
+    elif is_valid_url(original_config_file):
+        if local_files_only:
+            raise ValueError(
+                "`local_files_only` is set to True, but a URL was provided as `original_config_file`. "
+                "Please provide a valid local file path."
+            )

-
-        config_url = CONFIG_URLS["xl_refiner"]
+        original_config_file = BytesIO(requests.get(original_config_file).content)

-
-
+    else:
+        raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")

-
-        config_url = CONFIG_URLS["controlnet"]
+    original_config = yaml.safe_load(original_config_file)

-
-        config_url = CONFIG_URLS["v1"]
+    return original_config

-    original_config_file = BytesIO(requests.get(config_url).content)

-
+def is_clip_model(checkpoint):
+    if CHECKPOINT_KEY_NAMES["clip"] in checkpoint:
+        return True

+    return False

-def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=None):
-    def is_valid_url(url):
-        result = urlparse(url)
-        if result.scheme and result.netloc:
-            return True

-
+def is_clip_sdxl_model(checkpoint):
+    if CHECKPOINT_KEY_NAMES["clip_sdxl"] in checkpoint:
+        return True

-
-        original_config_file = infer_original_config_file(pipeline_class_name, checkpoint)
+    return False

-    elif os.path.isfile(original_config_file):
-        with open(original_config_file, "r") as fp:
-            original_config_file = fp.read()

-
-
+def is_open_clip_model(checkpoint):
+    if CHECKPOINT_KEY_NAMES["open_clip"] in checkpoint:
+        return True

-
-        raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")
+    return False

-    original_config = yaml.safe_load(original_config_file)

-
+def is_open_clip_sdxl_model(checkpoint):
+    if CHECKPOINT_KEY_NAMES["open_clip_sdxl"] in checkpoint:
+        return True

+    return False

-def infer_model_type(original_config, checkpoint, model_type=None):
-    if model_type is not None:
-        return model_type

-
-
-
-
-
-
-
+def is_open_clip_sdxl_refiner_model(checkpoint):
+    if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
+        return True
+
+    return False
+
+
+def is_clip_model_in_single_file(class_obj, checkpoint):
+    is_clip_in_checkpoint = any(
+        [
+            is_clip_model(checkpoint),
+            is_open_clip_model(checkpoint),
+            is_open_clip_sdxl_model(checkpoint),
+            is_open_clip_sdxl_refiner_model(checkpoint),
+        ]
     )
+    if (
+        class_obj.__name__ == "CLIPTextModel" or class_obj.__name__ == "CLIPTextModelWithProjection"
+    ) and is_clip_in_checkpoint:
+        return True
+
+    return False

-    if has_cond_stage_config:
-        model_type = original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1]

-
-
-
-
-
-
+def infer_diffusers_model_type(checkpoint):
+    if (
+        CHECKPOINT_KEY_NAMES["inpainting"] in checkpoint
+        and checkpoint[CHECKPOINT_KEY_NAMES["inpainting"]].shape[1] == 9
+    ):
+        if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
+            model_type = "inpainting_v2"
         else:
-            model_type = "
-    else:
-        raise ValueError("Unable to infer model type from config")
+            model_type = "inpainting"

-
+    elif CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
+        model_type = "v2"

-
+    elif CHECKPOINT_KEY_NAMES["playground-v2-5"] in checkpoint:
+        model_type = "playground-v2-5"

+    elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
+        model_type = "xl_base"

-
-
+    elif CHECKPOINT_KEY_NAMES["xl_refiner"] in checkpoint:
+        model_type = "xl_refiner"

+    elif CHECKPOINT_KEY_NAMES["upscale"] in checkpoint:
+        model_type = "upscale"

-
-
-    return image_size
+    elif CHECKPOINT_KEY_NAMES["controlnet"] in checkpoint:
+        model_type = "controlnet"

-
-
+    elif (
+        CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
+        and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 1536
+    ):
+        model_type = "stable_cascade_stage_c_lite"

-
-
-
+    elif (
+        CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
+        and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"]].shape[0] == 2048
+    ):
+        model_type = "stable_cascade_stage_c"

-    elif
-
-
+    elif (
+        CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint
+        and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 576
+    ):
+        model_type = "stable_cascade_stage_b_lite"

     elif (
-        "
-        and
+        CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"] in checkpoint
+        and checkpoint[CHECKPOINT_KEY_NAMES["stable_cascade_stage_b"]].shape[-1] == 640
     ):
-
-        # as it relies on a brittle global step parameter here
-        image_size = 512 if global_step == 875000 else 768
-        return image_size
+        model_type = "stable_cascade_stage_b"

     else:
-
+        model_type = "v1"
+
+    return model_type
+
+
+def fetch_diffusers_config(checkpoint):
+    model_type = infer_diffusers_model_type(checkpoint)
+    model_path = DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type]
+
+    return model_path
+
+
+def set_image_size(checkpoint, image_size=None):
+    if image_size:
         return image_size

+    model_type = infer_diffusers_model_type(checkpoint)
+    image_size = DIFFUSERS_TO_LDM_DEFAULT_IMAGE_SIZE_MAP[model_type]
+
+    return image_size
+

 # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
 def conv_attn_to_linear(checkpoint):
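
fetch_diffusers_config replaces the old YAML-based inference: the model variant is detected purely from which marker keys are present in the state dict (and their shapes), then mapped to a default Hub repo through DIFFUSERS_DEFAULT_PIPELINE_PATHS. A toy illustration of that probing, using a fake state dict containing only the SDXL base marker key (standalone sketch, not the library code):

```python
import torch

# Marker keys copied from CHECKPOINT_KEY_NAMES in the hunk above.
XL_BASE_KEY = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
V2_KEY = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"

fake_checkpoint = {XL_BASE_KEY: torch.zeros(1280)}

if V2_KEY in fake_checkpoint and fake_checkpoint[V2_KEY].shape[-1] == 1024:
    model_type = "v2"
elif XL_BASE_KEY in fake_checkpoint:
    model_type = "xl_base"
else:
    model_type = "v1"

# "xl_base" maps to {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0"}
print(model_type)
```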
@@ -488,10 +492,21 @@ def conv_attn_to_linear(checkpoint):
             checkpoint[key] = checkpoint[key][:, :, 0]


-def
+def create_unet_diffusers_config_from_ldm(
+    original_config, checkpoint, image_size=None, upcast_attention=None, num_in_channels=None
+):
     """
     Creates a config for the diffusers based on the config of the LDM model.
     """
+    if image_size is not None:
+        deprecation_message = (
+            "Configuring UNet2DConditionModel with the `image_size` argument to `from_single_file`"
+            "is deprecated and will be ignored in future versions."
+        )
+        deprecate("image_size", "1.0.0", deprecation_message)
+
+    image_size = set_image_size(checkpoint, image_size=image_size)
+
     if (
         "unet_config" in original_config["model"]["params"]
         and original_config["model"]["params"]["unet_config"] is not None
@@ -500,6 +515,16 @@ def create_unet_diffusers_config(original_config, image_size: int):
     else:
         unet_params = original_config["model"]["params"]["network_config"]["params"]

+    if num_in_channels is not None:
+        deprecation_message = (
+            "Configuring UNet2DConditionModel with the `num_in_channels` argument to `from_single_file`"
+            "is deprecated and will be ignored in future versions."
+        )
+        deprecate("image_size", "1.0.0", deprecation_message)
+        in_channels = num_in_channels
+    else:
+        in_channels = unet_params["in_channels"]
+
     vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
     block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]

@@ -564,7 +589,7 @@ def create_unet_diffusers_config(original_config, image_size: int):

     config = {
         "sample_size": image_size // vae_scale_factor,
-        "in_channels":
+        "in_channels": in_channels,
         "down_block_types": down_block_types,
         "block_out_channels": block_out_channels,
         "layers_per_block": unet_params["num_res_blocks"],
@@ -578,6 +603,14 @@ def create_unet_diffusers_config(original_config, image_size: int):
         "transformer_layers_per_block": transformer_layers_per_block,
     }

+    if upcast_attention is not None:
+        deprecation_message = (
+            "Configuring UNet2DConditionModel with the `upcast_attention` argument to `from_single_file`"
+            "is deprecated and will be ignored in future versions."
+        )
+        deprecate("image_size", "1.0.0", deprecation_message)
+        config["upcast_attention"] = upcast_attention
+
     if "disable_self_attentions" in unet_params:
         config["only_cross_attention"] = unet_params["disable_self_attentions"]

@@ -590,9 +623,18 @@ def create_unet_diffusers_config(original_config, image_size: int):
     return config


-def
+def create_controlnet_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, **kwargs):
+    if image_size is not None:
+        deprecation_message = (
+            "Configuring ControlNetModel with the `image_size` argument"
+            "is deprecated and will be ignored in future versions."
+        )
+        deprecate("image_size", "1.0.0", deprecation_message)
+
+    image_size = set_image_size(checkpoint, image_size=image_size)
+
     unet_params = original_config["model"]["params"]["control_stage_config"]["params"]
-    diffusers_unet_config =
+    diffusers_unet_config = create_unet_diffusers_config_from_ldm(original_config, image_size=image_size)

     controlnet_config = {
         "conditioning_channels": unet_params["hint_channels"],
@@ -613,15 +655,33 @@ def create_controlnet_diffusers_config(original_config, image_size: int):
     return controlnet_config


-def
+def create_vae_diffusers_config_from_ldm(original_config, checkpoint, image_size=None, scaling_factor=None):
     """
     Creates a config for the diffusers based on the config of the LDM model.
     """
+    if image_size is not None:
+        deprecation_message = (
+            "Configuring AutoencoderKL with the `image_size` argument"
+            "is deprecated and will be ignored in future versions."
+        )
+        deprecate("image_size", "1.0.0", deprecation_message)
+
+    image_size = set_image_size(checkpoint, image_size=image_size)
+
+    if "edm_mean" in checkpoint and "edm_std" in checkpoint:
+        latents_mean = checkpoint["edm_mean"]
+        latents_std = checkpoint["edm_std"]
+    else:
+        latents_mean = None
+        latents_std = None
+
     vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
     if (scaling_factor is None) and (latents_mean is not None) and (latents_std is not None):
         scaling_factor = PLAYGROUND_VAE_SCALING_FACTOR
+
     elif (scaling_factor is None) and ("scale_factor" in original_config["model"]["params"]):
         scaling_factor = original_config["model"]["params"]["scale_factor"]
+
     elif scaling_factor is None:
         scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR

@@ -658,16 +718,104 @@ def update_unet_resnet_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, ma
         )
         if mapping:
             diffusers_key = diffusers_key.replace(mapping["old"], mapping["new"])
-        new_checkpoint[diffusers_key] = checkpoint.
+        new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)


 def update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, mapping):
     for ldm_key in ldm_keys:
         diffusers_key = ldm_key.replace(mapping["old"], mapping["new"])
-        new_checkpoint[diffusers_key] = checkpoint.
+        new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
+

+def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
+    for ldm_key in keys:
+        diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")
+        new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
+
+
+def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
+    for ldm_key in keys:
+        diffusers_key = (
+            ldm_key.replace(mapping["old"], mapping["new"])
+            .replace("norm.weight", "group_norm.weight")
+            .replace("norm.bias", "group_norm.bias")
+            .replace("q.weight", "to_q.weight")
+            .replace("q.bias", "to_q.bias")
+            .replace("k.weight", "to_k.weight")
+            .replace("k.bias", "to_k.bias")
+            .replace("v.weight", "to_v.weight")
+            .replace("v.bias", "to_v.bias")
+            .replace("proj_out.weight", "to_out.0.weight")
+            .replace("proj_out.bias", "to_out.0.bias")
+        )
+        new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
+
+        # proj_attn.weight has to be converted from conv 1D to linear
+        shape = new_checkpoint[diffusers_key].shape

-
+        if len(shape) == 3:
+            new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0]
+        elif len(shape) == 4:
+            new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0]
+
+
+def convert_stable_cascade_unet_single_file_to_diffusers(checkpoint, **kwargs):
+    is_stage_c = "clip_txt_mapper.weight" in checkpoint
+
+    if is_stage_c:
+        state_dict = {}
+        for key in checkpoint.keys():
+            if key.endswith("in_proj_weight"):
+                weights = checkpoint[key].chunk(3, 0)
+                state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
+                state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
+                state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
+            elif key.endswith("in_proj_bias"):
+                weights = checkpoint[key].chunk(3, 0)
+                state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
+                state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
+                state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
+            elif key.endswith("out_proj.weight"):
+                weights = checkpoint[key]
+                state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
+            elif key.endswith("out_proj.bias"):
+                weights = checkpoint[key]
+                state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
+            else:
+                state_dict[key] = checkpoint[key]
+    else:
+        state_dict = {}
+        for key in checkpoint.keys():
+            if key.endswith("in_proj_weight"):
+                weights = checkpoint[key].chunk(3, 0)
+                state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
+                state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
+                state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
+            elif key.endswith("in_proj_bias"):
+                weights = checkpoint[key].chunk(3, 0)
+                state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
+                state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
+                state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
+            elif key.endswith("out_proj.weight"):
+                weights = checkpoint[key]
+                state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
+            elif key.endswith("out_proj.bias"):
+                weights = checkpoint[key]
+                state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
+            # rename clip_mapper to clip_txt_pooled_mapper
+            elif key.endswith("clip_mapper.weight"):
+                weights = checkpoint[key]
+                state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
+            elif key.endswith("clip_mapper.bias"):
+                weights = checkpoint[key]
+                state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
+            else:
+                state_dict[key] = checkpoint[key]
+
+    return state_dict
+
+
+def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False, **kwargs):
     """
     Takes a state dict and a config, and returns a converted checkpoint.
     """
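
The Stable Cascade converter above splits each fused attention in_proj tensor into separate q/k/v projections with chunk(3, 0). A toy demonstration of that split on a random tensor (illustration only; sizes are made up):

```python
import torch

hidden = 8
# Fused qkv projection, shaped like the in_proj_weight tensors in the checkpoint.
in_proj_weight = torch.randn(3 * hidden, hidden)

# Split along dim 0 into three equal blocks, exactly as the converter does.
to_q, to_k, to_v = in_proj_weight.chunk(3, 0)
assert to_q.shape == (hidden, hidden)
```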
@@ -686,7 +834,7 @@ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False):
         for key in keys:
             if key.startswith("model.diffusion_model"):
                 flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
-                unet_state_dict[key.replace(unet_key, "")] = checkpoint.
+                unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(flat_ema_key)
     else:
         if sum(k.startswith("model_ema") for k in keys) > 100:
             logger.warning(
@@ -695,7 +843,7 @@ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False):
             )
         for key in keys:
             if key.startswith(unet_key):
-                unet_state_dict[key.replace(unet_key, "")] = checkpoint.
+                unet_state_dict[key.replace(unet_key, "")] = checkpoint.get(key)

     new_checkpoint = {}
     ldm_unet_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["layers"]
@@ -756,10 +904,10 @@ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False):
         )

         if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
-            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.get(
                 f"input_blocks.{i}.0.op.weight"
             )
-            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.get(
                 f"input_blocks.{i}.0.op.bias"
             )

@@ -773,19 +921,22 @@ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False):
     )

     # Mid blocks
-
-
-
-
-
-
-
-
-
-
-
-
-
+    for key in middle_blocks.keys():
+        diffusers_key = max(key - 1, 0)
+        if key % 2 == 0:
+            update_unet_resnet_ldm_to_diffusers(
+                middle_blocks[key],
+                new_checkpoint,
+                unet_state_dict,
+                mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"},
+            )
+        else:
+            update_unet_attention_ldm_to_diffusers(
+                middle_blocks[key],
+                new_checkpoint,
+                unet_state_dict,
+                mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"},
+            )

     # Up Blocks
     for i in range(num_output_blocks):
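
The new mid-block loop replaces the previous hard-coded resnet_0 / attentions / resnet_1 handling: even LDM middle_block indices are resnets, odd indices are attentions, and the target index is max(key - 1, 0). A quick sketch of that index arithmetic (standalone, for illustration):

```python
# middle_block.{key} -> mid_block.{resnets|attentions}.{max(key - 1, 0)}
for key in range(3):
    diffusers_key = max(key - 1, 0)
    kind = "resnets" if key % 2 == 0 else "attentions"
    print(f"middle_block.{key} -> mid_block.{kind}.{diffusers_key}")
# middle_block.0 -> mid_block.resnets.0
# middle_block.1 -> mid_block.attentions.0
# middle_block.2 -> mid_block.resnets.1
```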
@@ -834,6 +985,7 @@ def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False):
def convert_controlnet_checkpoint(
    checkpoint,
    config,
+    **kwargs,
):
    # Some controlnet ckpt files are distributed independently from the rest of the
    # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
@@ -846,7 +998,7 @@ def convert_controlnet_checkpoint(
        controlnet_key = LDM_CONTROLNET_KEY
        for key in keys:
            if key.startswith(controlnet_key):
-                controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.
+                controlnet_state_dict[key.replace(controlnet_key, "")] = checkpoint.get(key)

    new_checkpoint = {}
    ldm_controlnet_keys = DIFFUSERS_TO_LDM_MAPPING["controlnet"]["layers"]
@@ -880,10 +1032,10 @@ def convert_controlnet_checkpoint(
        )

        if f"input_blocks.{i}.0.op.weight" in controlnet_state_dict:
-            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = controlnet_state_dict.get(
                f"input_blocks.{i}.0.op.weight"
            )
-            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = controlnet_state_dict.get(
                f"input_blocks.{i}.0.op.bias"
            )

@@ -898,8 +1050,8 @@ def convert_controlnet_checkpoint(

    # controlnet down blocks
    for i in range(num_input_blocks):
-        new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.
-        new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.
+        new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = controlnet_state_dict.get(f"zero_convs.{i}.0.weight")
+        new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = controlnet_state_dict.get(f"zero_convs.{i}.0.bias")

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len(
@@ -909,33 +1061,28 @@ def convert_controlnet_checkpoint(
        layer_id: [key for key in controlnet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }
-    if middle_blocks:
-        resnet_0 = middle_blocks[0]
-        attentions = middle_blocks[1]
-        resnet_1 = middle_blocks[2]

- [17 removed lines not captured in this view]
-        )
+    # Mid blocks
+    for key in middle_blocks.keys():
+        diffusers_key = max(key - 1, 0)
+        if key % 2 == 0:
+            update_unet_resnet_ldm_to_diffusers(
+                middle_blocks[key],
+                new_checkpoint,
+                controlnet_state_dict,
+                mapping={"old": f"middle_block.{key}", "new": f"mid_block.resnets.{diffusers_key}"},
+            )
+        else:
+            update_unet_attention_ldm_to_diffusers(
+                middle_blocks[key],
+                new_checkpoint,
+                controlnet_state_dict,
+                mapping={"old": f"middle_block.{key}", "new": f"mid_block.attentions.{diffusers_key}"},
+            )

    # mid block
-    new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.
-    new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.
+    new_checkpoint["controlnet_mid_block.weight"] = controlnet_state_dict.get("middle_block_out.0.weight")
+    new_checkpoint["controlnet_mid_block.bias"] = controlnet_state_dict.get("middle_block_out.0.bias")

    # controlnet cond embedding blocks
    cond_embedding_blocks = {
@@ -949,88 +1096,16 @@ def convert_controlnet_checkpoint(
        diffusers_idx = idx - 1
        cond_block_id = 2 * idx

-        new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.
+        new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = controlnet_state_dict.get(
            f"input_hint_block.{cond_block_id}.weight"
        )
-        new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.
+        new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = controlnet_state_dict.get(
            f"input_hint_block.{cond_block_id}.bias"
        )

    return new_checkpoint


-def create_diffusers_controlnet_model_from_ldm(
-    pipeline_class_name, original_config, checkpoint, upcast_attention=False, image_size=None, torch_dtype=None
-):
-    # import here to avoid circular imports
-    from ..models import ControlNetModel
-
-    image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
-
-    diffusers_config = create_controlnet_diffusers_config(original_config, image_size=image_size)
-    diffusers_config["upcast_attention"] = upcast_attention
-
-    diffusers_format_controlnet_checkpoint = convert_controlnet_checkpoint(checkpoint, diffusers_config)
-
-    ctx = init_empty_weights if is_accelerate_available() else nullcontext
-    with ctx():
-        controlnet = ControlNetModel(**diffusers_config)
-
-    if is_accelerate_available():
-        from ..models.modeling_utils import load_model_dict_into_meta
-
-        unexpected_keys = load_model_dict_into_meta(
-            controlnet, diffusers_format_controlnet_checkpoint, dtype=torch_dtype
-        )
-        if controlnet._keys_to_ignore_on_load_unexpected is not None:
-            for pat in controlnet._keys_to_ignore_on_load_unexpected:
-                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-
-        if len(unexpected_keys) > 0:
-            logger.warning(
-                f"Some weights of the model checkpoint were not used when initializing {controlnet.__name__}: \n {[', '.join(unexpected_keys)]}"
-            )
-    else:
-        controlnet.load_state_dict(diffusers_format_controlnet_checkpoint)
-
-    if torch_dtype is not None:
-        controlnet = controlnet.to(torch_dtype)
-
-    return {"controlnet": controlnet}
-
-
-def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
-    for ldm_key in keys:
-        diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")
-        new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key)
-
-
-def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
-    for ldm_key in keys:
-        diffusers_key = (
-            ldm_key.replace(mapping["old"], mapping["new"])
-            .replace("norm.weight", "group_norm.weight")
-            .replace("norm.bias", "group_norm.bias")
-            .replace("q.weight", "to_q.weight")
-            .replace("q.bias", "to_q.bias")
-            .replace("k.weight", "to_k.weight")
-            .replace("k.bias", "to_k.bias")
-            .replace("v.weight", "to_v.weight")
-            .replace("v.bias", "to_v.bias")
-            .replace("proj_out.weight", "to_out.0.weight")
-            .replace("proj_out.bias", "to_out.0.bias")
-        )
-        new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key)
-
-        # proj_attn.weight has to be converted from conv 1D to linear
-        shape = new_checkpoint[diffusers_key].shape
-
-        if len(shape) == 3:
-            new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0]
-        elif len(shape) == 4:
-            new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0]
-
-
def convert_ldm_vae_checkpoint(checkpoint, config):
    # extract state dict for VAE
    # remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys
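Across these hunks the converters switch from `dict.pop` to `dict.get` when reading weights out of the source state dict. A small sketch (toy values, not code from the package) of the practical difference: `.get` leaves the loaded checkpoint intact, so the same dict can be handed to more than one component converter:

```python
# Toy state dict standing in for a loaded single-file checkpoint.
state_dict = {"middle_block_out.0.weight": "w", "middle_block_out.0.bias": "b"}

converted = {
    "controlnet_mid_block.weight": state_dict.get("middle_block_out.0.weight"),
    "controlnet_mid_block.bias": state_dict.get("middle_block_out.0.bias"),
}

# Unlike .pop(), .get() does not consume the source keys.
assert "middle_block_out.0.weight" in state_dict
print(converted)
```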
@@ -1063,10 +1138,10 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
            mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"},
        )
        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
-            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.
+            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.get(
                f"encoder.down.{i}.downsample.conv.weight"
            )
-            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.
+            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.get(
                f"encoder.down.{i}.downsample.conv.bias"
            )

@@ -1131,18 +1206,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
    return new_checkpoint


-def
-    try:
-        config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)
-    except Exception:
-        raise ValueError(
-            f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'."
-        )
-
-    ctx = init_empty_weights if is_accelerate_available() else nullcontext
-    with ctx():
-        text_model = CLIPTextModel(config)
-
+def convert_ldm_clip_checkpoint(checkpoint):
    keys = list(checkpoint.keys())
    text_model_dict = {}

@@ -1152,57 +1216,26 @@ def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_
        for prefix in remove_prefixes:
            if key.startswith(prefix):
                diffusers_key = key.replace(prefix, "")
-                text_model_dict[diffusers_key] = checkpoint
-
-    if is_accelerate_available():
-        from ..models.modeling_utils import load_model_dict_into_meta
-
-        unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict, dtype=torch_dtype)
-        if text_model._keys_to_ignore_on_load_unexpected is not None:
-            for pat in text_model._keys_to_ignore_on_load_unexpected:
-                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+                text_model_dict[diffusers_key] = checkpoint.get(key)

-
-        logger.warning(
-            f"Some weights of the model checkpoint were not used when initializing {text_model.__class__.__name__}: \n {[', '.join(unexpected_keys)]}"
-        )
-    else:
-        if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)):
-            text_model_dict.pop("text_model.embeddings.position_ids", None)
+    return text_model_dict

-        text_model.load_state_dict(text_model_dict)

-
-
-
-    return text_model
-
-
-def create_text_encoder_from_open_clip_checkpoint(
-    config_name,
+def convert_open_clip_checkpoint(
+    text_model,
    checkpoint,
    prefix="cond_stage_model.model.",
-    has_projection=False,
-    local_files_only=False,
-    torch_dtype=None,
-    **config_kwargs,
):
-    try:
-        config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only)
-    except Exception:
-        raise ValueError(
-            f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: '{config_name}'."
-        )
-
-    ctx = init_empty_weights if is_accelerate_available() else nullcontext
-    with ctx():
-        text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)
-
    text_model_dict = {}
    text_proj_key = prefix + "text_projection"
-
-
-
+
+    if text_proj_key in checkpoint:
+        text_proj_dim = int(checkpoint[text_proj_key].shape[0])
+    elif hasattr(text_model.config, "projection_dim"):
+        text_proj_dim = text_model.config.projection_dim
+    else:
+        text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
+
    text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")

    keys = list(checkpoint.keys())
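The slimmed-down `convert_ldm_clip_checkpoint` above now only strips the LDM text-encoder prefixes and returns a plain state dict; instantiating and loading the actual CLIP model moves into `create_diffusers_clip_model_from_ldm` further down. A toy sketch (hypothetical keys, prefixes chosen for illustration) of the prefix stripping:

```python
# Hypothetical checkpoint entries; only keys under a known text-encoder prefix are kept.
checkpoint = {
    "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "tensor",
    "first_stage_model.decoder.conv_in.weight": "ignored",
}
remove_prefixes = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]

text_model_dict = {}
for key in checkpoint:
    for prefix in remove_prefixes:
        if key.startswith(prefix):
            diffusers_key = key.replace(prefix, "")
            text_model_dict[diffusers_key] = checkpoint.get(key)

print(text_model_dict)
# {'text_model.embeddings.token_embedding.weight': 'tensor'}
```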
@@ -1235,309 +1268,165 @@ def create_text_encoder_from_open_clip_checkpoint(
        )

        if key.endswith(".in_proj_weight"):
-            weight_value = checkpoint
+            weight_value = checkpoint.get(key)

-            text_model_dict[diffusers_key + ".q_proj.weight"] = weight_value[:text_proj_dim, :]
-            text_model_dict[diffusers_key + ".k_proj.weight"] =
-
+            text_model_dict[diffusers_key + ".q_proj.weight"] = weight_value[:text_proj_dim, :].clone().detach()
+            text_model_dict[diffusers_key + ".k_proj.weight"] = (
+                weight_value[text_proj_dim : text_proj_dim * 2, :].clone().detach()
+            )
+            text_model_dict[diffusers_key + ".v_proj.weight"] = weight_value[text_proj_dim * 2 :, :].clone().detach()

        elif key.endswith(".in_proj_bias"):
-            weight_value = checkpoint
-            text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim]
-            text_model_dict[diffusers_key + ".k_proj.bias"] =
-
-        else:
-            text_model_dict[diffusers_key] = checkpoint[key]
-
-    if is_accelerate_available():
-        from ..models.modeling_utils import load_model_dict_into_meta
-
-        unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict, dtype=torch_dtype)
-        if text_model._keys_to_ignore_on_load_unexpected is not None:
-            for pat in text_model._keys_to_ignore_on_load_unexpected:
-                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+            weight_value = checkpoint.get(key)
+            text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim].clone().detach()
+            text_model_dict[diffusers_key + ".k_proj.bias"] = (
+                weight_value[text_proj_dim : text_proj_dim * 2].clone().detach()
            )
+            text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :].clone().detach()
+        else:
+            text_model_dict[diffusers_key] = checkpoint.get(key)

-
-
-    text_model_dict.pop("text_model.embeddings.position_ids", None)
-
-    text_model.load_state_dict(text_model_dict)
-
-    if torch_dtype is not None:
-        text_model = text_model.to(torch_dtype)
+    if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)):
+        text_model_dict.pop("text_model.embeddings.position_ids", None)

-    return
+    return text_model_dict


-def
-
-    original_config,
+def create_diffusers_clip_model_from_ldm(
+    cls,
    checkpoint,
-
-
-    extract_ema=False,
-    image_size=None,
+    subfolder="",
+    config=None,
    torch_dtype=None,
-
+    local_files_only=None,
+    is_legacy_loading=False,
):
-
+    if config:
+        config = {"pretrained_model_name_or_path": config}
+    else:
+        config = fetch_diffusers_config(checkpoint)

- [8 removed lines not captured in this view]
+    # For backwards compatibility
+    # Older versions of `from_single_file` expected CLIP configs to be placed in their original transformers model repo
+    # in the cache_dir, rather than in a subfolder of the Diffusers model
+    if is_legacy_loading:
+        logger.warning(
+            (
+                "Detected legacy CLIP loading behavior. Please run `from_single_file` with `local_files_only=False once to update "
+                "the local cache directory with the necessary CLIP model config files. "
+                "Attempting to load CLIP model from legacy cache directory."
+            )
+        )

-
-
+    if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint):
+        clip_config = "openai/clip-vit-large-patch14"
+        config["pretrained_model_name_or_path"] = clip_config
+        subfolder = ""

-
-
+    elif is_open_clip_model(checkpoint):
+        clip_config = "stabilityai/stable-diffusion-2"
+        config["pretrained_model_name_or_path"] = clip_config
+        subfolder = "text_encoder"

- [4 removed lines not captured in this view]
-    unet_config["in_channels"] = num_in_channels
-    if upcast_attention is not None:
-        unet_config["upcast_attention"] = upcast_attention
+    else:
+        clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
+        config["pretrained_model_name_or_path"] = clip_config
+        subfolder = ""

-
+    model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
    ctx = init_empty_weights if is_accelerate_available() else nullcontext
-
    with ctx():
-
+        model = cls(model_config)

-
-        from ..models.modeling_utils import load_model_dict_into_meta
+    position_embedding_dim = model.text_model.embeddings.position_embedding.weight.shape[-1]

-
-
-            for pat in unet._keys_to_ignore_on_load_unexpected:
-                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-
-        if len(unexpected_keys) > 0:
-            logger.warning(
-                f"Some weights of the model checkpoint were not used when initializing {unet.__name__}: \n {[', '.join(unexpected_keys)]}"
-            )
-    else:
-        unet.load_state_dict(diffusers_format_unet_checkpoint)
-
-    if torch_dtype is not None:
-        unet = unet.to(torch_dtype)
+    if is_clip_model(checkpoint):
+        diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)

-
+    elif (
+        is_clip_sdxl_model(checkpoint)
+        and checkpoint[CHECKPOINT_KEY_NAMES["clip_sdxl"]].shape[-1] == position_embedding_dim
+    ):
+        diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)

+    elif is_open_clip_model(checkpoint):
+        prefix = "cond_stage_model.model."
+        diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)

- [6 removed lines not captured in this view]
-    torch_dtype=None,
-    model_type=None,
-):
-    # import here to avoid circular imports
-    from ..models import AutoencoderKL
+    elif (
+        is_open_clip_sdxl_model(checkpoint)
+        and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sdxl"]].shape[-1] == position_embedding_dim
+    ):
+        prefix = "conditioner.embedders.1.model."
+        diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)

-
-
-
-    model_type = infer_model_type(original_config, checkpoint, model_type)
+    elif is_open_clip_sdxl_refiner_model(checkpoint):
+        prefix = "conditioner.embedders.0.model."
+        diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)

-    if model_type == "Playground":
-        edm_mean = (
-            checkpoint["edm_mean"].to(dtype=torch_dtype).tolist() if torch_dtype else checkpoint["edm_mean"].tolist()
-        )
-        edm_std = (
-            checkpoint["edm_std"].to(dtype=torch_dtype).tolist() if torch_dtype else checkpoint["edm_std"].tolist()
-        )
    else:
-
-        edm_std = None
-
-    vae_config = create_vae_diffusers_config(
-        original_config,
-        image_size=image_size,
-        scaling_factor=scaling_factor,
-        latents_mean=edm_mean,
-        latents_std=edm_std,
-    )
-    diffusers_format_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
-    ctx = init_empty_weights if is_accelerate_available() else nullcontext
-
-    with ctx():
-        vae = AutoencoderKL(**vae_config)
+        raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")

    if is_accelerate_available():
-
-
-
-        if vae._keys_to_ignore_on_load_unexpected is not None:
-            for pat in vae._keys_to_ignore_on_load_unexpected:
+        unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+        if model._keys_to_ignore_on_load_unexpected is not None:
+            for pat in model._keys_to_ignore_on_load_unexpected:
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

        if len(unexpected_keys) > 0:
            logger.warning(
-                f"Some weights of the model checkpoint were not used when initializing {
+                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
            )
+
    else:
-
+        model.load_state_dict(diffusers_format_checkpoint)

    if torch_dtype is not None:
-
+        model.to(torch_dtype)

-
+    model.eval()

+    return model

-
-
+
+def _legacy_load_scheduler(
+    cls,
    checkpoint,
-
-
-
+    component_name,
+    original_config=None,
+    **kwargs,
):
-
+    scheduler_type = kwargs.get("scheduler_type", None)
+    prediction_type = kwargs.get("prediction_type", None)

-    if
- [4 removed lines not captured in this view]
-        text_encoder = create_text_encoder_from_open_clip_checkpoint(
-            config_name, checkpoint, local_files_only=local_files_only, torch_dtype=torch_dtype, **config_kwargs
-        )
-        tokenizer = CLIPTokenizer.from_pretrained(
-            config_name, subfolder="tokenizer", local_files_only=local_files_only
-        )
-    except Exception:
-        raise ValueError(
-            f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder in the following path: '{config_name}'."
-        )
-    else:
-        return {"text_encoder": text_encoder, "tokenizer": tokenizer}
-
-    elif model_type == "FrozenCLIPEmbedder":
-        try:
-            config_name = "openai/clip-vit-large-patch14"
-            text_encoder = create_text_encoder_from_ldm_clip_checkpoint(
-                config_name,
-                checkpoint,
-                local_files_only=local_files_only,
-                torch_dtype=torch_dtype,
-            )
-            tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
-
-        except Exception:
-            raise ValueError(
-                f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: '{config_name}'."
-            )
-        else:
-            return {"text_encoder": text_encoder, "tokenizer": tokenizer}
-
-    elif model_type == "SDXL-Refiner":
-        config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
-        config_kwargs = {"projection_dim": 1280}
-        prefix = "conditioner.embedders.0.model."
-
-        try:
-            tokenizer_2 = CLIPTokenizer.from_pretrained(config_name, pad_token="!", local_files_only=local_files_only)
-            text_encoder_2 = create_text_encoder_from_open_clip_checkpoint(
-                config_name,
-                checkpoint,
-                prefix=prefix,
-                has_projection=True,
-                local_files_only=local_files_only,
-                torch_dtype=torch_dtype,
-                **config_kwargs,
-            )
-        except Exception:
-            raise ValueError(
-                f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder_2 and tokenizer_2 in the following path: {config_name} with `pad_token` set to '!'."
-            )
-
-        else:
-            return {
-                "text_encoder": None,
-                "tokenizer": None,
-                "tokenizer_2": tokenizer_2,
-                "text_encoder_2": text_encoder_2,
-            }
-
-    elif model_type in ["SDXL", "Playground"]:
-        try:
-            config_name = "openai/clip-vit-large-patch14"
-            tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
-            text_encoder = create_text_encoder_from_ldm_clip_checkpoint(
-                config_name, checkpoint, local_files_only=local_files_only, torch_dtype=torch_dtype
-            )
-
-        except Exception:
-            raise ValueError(
-                f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder and tokenizer in the following path: 'openai/clip-vit-large-patch14'."
-            )
-
-        try:
-            config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
-            config_kwargs = {"projection_dim": 1280}
-            prefix = "conditioner.embedders.1.model."
-            tokenizer_2 = CLIPTokenizer.from_pretrained(config_name, pad_token="!", local_files_only=local_files_only)
-            text_encoder_2 = create_text_encoder_from_open_clip_checkpoint(
-                config_name,
-                checkpoint,
-                prefix=prefix,
-                has_projection=True,
-                local_files_only=local_files_only,
-                torch_dtype=torch_dtype,
-                **config_kwargs,
-            )
-        except Exception:
-            raise ValueError(
-                f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder_2 and tokenizer_2 in the following path: {config_name} with `pad_token` set to '!'."
-            )
-
-        return {
-            "tokenizer": tokenizer,
-            "text_encoder": text_encoder,
-            "tokenizer_2": tokenizer_2,
-            "text_encoder_2": text_encoder_2,
-        }
-
-    return
+    if scheduler_type is not None:
+        deprecation_message = (
+            "Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`."
+        )
+        deprecate("scheduler_type", "1.0.0", deprecation_message)

+    if prediction_type is not None:
+        deprecation_message = (
+            "Please configure an instance of a Scheduler with the appropriate `prediction_type` "
+            "and pass the object directly to the `scheduler` argument in `from_single_file`."
+        )
+        deprecate("prediction_type", "1.0.0", deprecation_message)

-
-
-    original_config,
-    checkpoint,
-    prediction_type=None,
-    scheduler_type="ddim",
-    model_type=None,
-):
-    scheduler_config = get_default_scheduler_config()
-    model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)
+    scheduler_config = SCHEDULER_DEFAULT_CONFIG
+    model_type = infer_diffusers_model_type(checkpoint=checkpoint)

    global_step = checkpoint["global_step"] if "global_step" in checkpoint else None

-
+    if original_config:
+        num_train_timesteps = getattr(original_config["model"]["params"], "timesteps", 1000)
+    else:
+        num_train_timesteps = 1000
+
    scheduler_config["num_train_timesteps"] = num_train_timesteps

-    if
-        "parameterization" in original_config["model"]["params"]
-        and original_config["model"]["params"]["parameterization"] == "v"
-    ):
+    if model_type == "v2":
        if prediction_type is None:
-            # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
-            # as it relies on a brittle global step parameter here
+            # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` # as it relies on a brittle global step parameter here
            prediction_type = "epsilon" if global_step == 875000 else "v_prediction"

    else:
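In `convert_open_clip_checkpoint` above, the packed open-CLIP attention projection (`.in_proj_weight`, shape `(3 * dim, dim)`) is sliced into separate q/k/v tensors, now via `.clone().detach()` so the returned state dict does not alias the source checkpoint. A minimal sketch (hypothetical toy dimension) of the slicing:

```python
import torch

text_proj_dim = 4  # toy value; SD2's open-CLIP text encoder uses a much larger width
weight_value = torch.randn(3 * text_proj_dim, text_proj_dim)  # stand-in for in_proj_weight

# Slice the packed projection into independent q/k/v matrices.
q = weight_value[:text_proj_dim, :].clone().detach()
k = weight_value[text_proj_dim : text_proj_dim * 2, :].clone().detach()
v = weight_value[text_proj_dim * 2 :, :].clone().detach()

assert q.shape == k.shape == v.shape == (text_proj_dim, text_proj_dim)
```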
@@ -1545,20 +1434,44 @@ def create_scheduler_from_ldm(

    scheduler_config["prediction_type"] = prediction_type

-    if model_type in ["
+    if model_type in ["xl_base", "xl_refiner"]:
        scheduler_type = "euler"
-    elif model_type == "
+    elif model_type == "playground":
        scheduler_type = "edm_dpm_solver_multistep"
    else:
-
-
+        if original_config:
+            beta_start = original_config["model"]["params"].get("linear_start")
+            beta_end = original_config["model"]["params"].get("linear_end")
+
+        else:
+            beta_start = 0.02
+            beta_end = 0.085
+
        scheduler_config["beta_start"] = beta_start
        scheduler_config["beta_end"] = beta_end
        scheduler_config["beta_schedule"] = "scaled_linear"
        scheduler_config["clip_sample"] = False
        scheduler_config["set_alpha_to_one"] = False

-
+    # to deal with an edge case StableDiffusionUpscale pipeline has two schedulers
+    if component_name == "low_res_scheduler":
+        return cls.from_config(
+            {
+                "beta_end": 0.02,
+                "beta_schedule": "scaled_linear",
+                "beta_start": 0.0001,
+                "clip_sample": True,
+                "num_train_timesteps": 1000,
+                "prediction_type": "epsilon",
+                "trained_betas": None,
+                "variance_type": "fixed_small",
+            }
+        )
+
+    if scheduler_type is None:
+        return cls.from_config(scheduler_config)
+
+    elif scheduler_type == "pndm":
        scheduler_config["skip_prk_steps"] = True
        scheduler = PNDMScheduler.from_config(scheduler_config)

@@ -1603,15 +1516,46 @@ def create_scheduler_from_ldm(
    else:
        raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")

-
-    scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", subfolder="scheduler")
-    low_res_scheduler = DDPMScheduler.from_pretrained(
-        "stabilityai/stable-diffusion-x4-upscaler", subfolder="low_res_scheduler"
-    )
+    return scheduler

-    return {
-        "scheduler": scheduler,
-        "low_res_scheduler": low_res_scheduler,
-    }

-
+def _legacy_load_clip_tokenizer(cls, checkpoint, config=None, local_files_only=False):
+    if config:
+        config = {"pretrained_model_name_or_path": config}
+    else:
+        config = fetch_diffusers_config(checkpoint)
+
+    if is_clip_model(checkpoint) or is_clip_sdxl_model(checkpoint):
+        clip_config = "openai/clip-vit-large-patch14"
+        config["pretrained_model_name_or_path"] = clip_config
+        subfolder = ""
+
+    elif is_open_clip_model(checkpoint):
+        clip_config = "stabilityai/stable-diffusion-2"
+        config["pretrained_model_name_or_path"] = clip_config
+        subfolder = "tokenizer"
+
+    else:
+        clip_config = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
+        config["pretrained_model_name_or_path"] = clip_config
+        subfolder = ""
+
+    tokenizer = cls.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
+
+    return tokenizer
+
+
+def _legacy_load_safety_checker(local_files_only, torch_dtype):
+    # Support for loading safety checker components using the deprecated
+    # `load_safety_checker` argument.
+
+    from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+
+    feature_extractor = AutoImageProcessor.from_pretrained(
+        "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype
+    )
+    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
+        "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype
+    )
+
+    return {"safety_checker": safety_checker, "feature_extractor": feature_extractor}