diffusers-0.26.2-py3-none-any.whl → diffusers-0.27.0-py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- diffusers/__init__.py +20 -1
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/diffusers_cli.py +1 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +7 -3
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +2 -2
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/image_processor.py +110 -4
- diffusers/loaders/autoencoder.py +28 -8
- diffusers/loaders/controlnet.py +17 -8
- diffusers/loaders/ip_adapter.py +86 -23
- diffusers/loaders/lora.py +105 -310
- diffusers/loaders/lora_conversion_utils.py +1 -1
- diffusers/loaders/peft.py +1 -1
- diffusers/loaders/single_file.py +51 -12
- diffusers/loaders/single_file_utils.py +278 -49
- diffusers/loaders/textual_inversion.py +23 -4
- diffusers/loaders/unet.py +195 -41
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +3 -1
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +26 -36
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +171 -114
- diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +1 -1
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flax.py +1 -1
- diffusers/models/downsampling.py +8 -12
- diffusers/models/dual_transformer_2d.py +1 -1
- diffusers/models/embeddings.py +3 -4
- diffusers/models/embeddings_flax.py +1 -1
- diffusers/models/lora.py +33 -10
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +4 -6
- diffusers/models/normalization.py +1 -1
- diffusers/models/resnet.py +31 -58
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/t5_film_transformer.py +1 -1
- diffusers/models/transformer_2d.py +1 -1
- diffusers/models/transformer_temporal.py +1 -1
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/t5_film_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +29 -31
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unet_1d.py +1 -1
- diffusers/models/unet_1d_blocks.py +1 -1
- diffusers/models/unet_2d.py +1 -1
- diffusers/models/unet_2d_blocks.py +1 -1
- diffusers/models/unet_2d_condition.py +1 -1
- diffusers/models/unets/__init__.py +1 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +4 -4
- diffusers/models/unets/unet_2d_blocks.py +238 -98
- diffusers/models/unets/unet_2d_blocks_flax.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +420 -323
- diffusers/models/unets/unet_2d_condition_flax.py +21 -12
- diffusers/models/unets/unet_3d_blocks.py +50 -40
- diffusers/models/unets/unet_3d_condition.py +47 -8
- diffusers/models/unets/unet_i2vgen_xl.py +75 -30
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +48 -8
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +610 -0
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +10 -16
- diffusers/models/vae_flax.py +1 -1
- diffusers/models/vq_model.py +1 -1
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +26 -0
- diffusers/pipelines/amused/pipeline_amused.py +1 -1
- diffusers/pipelines/amused/pipeline_amused_img2img.py +1 -1
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff.py +162 -417
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +165 -137
- diffusers/pipelines/animatediff/pipeline_output.py +7 -6
- diffusers/pipelines/audioldm/pipeline_audioldm.py +3 -19
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +3 -3
- diffusers/pipelines/auto_pipeline.py +7 -16
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +90 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +98 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +92 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +145 -70
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +126 -89
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +108 -96
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -1
- diffusers/pipelines/ddim/pipeline_ddim.py +1 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +5 -5
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +5 -5
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +5 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +10 -120
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -91
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +1 -1
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +1 -1
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +5 -4
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +5 -4
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +7 -22
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -39
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +5 -5
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -22
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -2
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +1 -1
- diffusers/pipelines/free_init_utils.py +184 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +22 -104
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +2 -2
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +104 -93
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +112 -74
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ledits_pp/__init__.py +55 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +1505 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +1797 -0
- diffusers/pipelines/ledits_pp/pipeline_output.py +43 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +3 -19
- diffusers/pipelines/onnx_utils.py +1 -1
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +3 -3
- diffusers/pipelines/pia/pipeline_pia.py +168 -327
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +508 -0
- diffusers/pipelines/pipeline_utils.py +188 -534
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +56 -10
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +3 -3
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/__init__.py +50 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +482 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +311 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +638 -0
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +4 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +90 -146
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +4 -32
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -119
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -119
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +13 -59
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -31
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -33
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -21
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -21
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +5 -21
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +9 -38
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -34
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +6 -35
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +7 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +4 -124
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +282 -80
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +94 -46
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +3 -3
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +6 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +96 -148
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -154
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +98 -153
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +25 -87
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +89 -80
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +5 -49
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +80 -88
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +15 -86
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +20 -93
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +5 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +3 -19
- diffusers/pipelines/unclip/pipeline_unclip.py +1 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -1
- diffusers/pipelines/unclip/text_proj.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +35 -35
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +4 -21
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -5
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +2 -2
- diffusers/schedulers/__init__.py +7 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +1 -1
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +42 -19
- diffusers/schedulers/scheduling_ddim.py +2 -4
- diffusers/schedulers/scheduling_ddim_flax.py +13 -5
- diffusers/schedulers/scheduling_ddim_inverse.py +2 -4
- diffusers/schedulers/scheduling_ddim_parallel.py +2 -4
- diffusers/schedulers/scheduling_ddpm.py +2 -4
- diffusers/schedulers/scheduling_ddpm_flax.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -4
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +46 -19
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -21
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +9 -7
- diffusers/schedulers/scheduling_dpmsolver_sde.py +35 -35
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +52 -21
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +683 -0
- diffusers/schedulers/scheduling_edm_euler.py +381 -0
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +43 -15
- diffusers/schedulers/scheduling_euler_discrete.py +42 -17
- diffusers/schedulers/scheduling_euler_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_heun_discrete.py +35 -35
- diffusers/schedulers/scheduling_ipndm.py +37 -11
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +44 -44
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +44 -44
- diffusers/schedulers/scheduling_karras_ve_flax.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +38 -14
- diffusers/schedulers/scheduling_lms_discrete.py +43 -15
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +2 -4
- diffusers/schedulers/scheduling_pndm_flax.py +2 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +41 -9
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +1 -1
- diffusers/schedulers/scheduling_tcd.py +686 -0
- diffusers/schedulers/scheduling_unclip.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +46 -19
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +9 -2
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +1 -1
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +75 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +3 -3
- diffusers/utils/hub_utils.py +60 -16
- diffusers/utils/import_utils.py +15 -1
- diffusers/utils/loading_utils.py +2 -0
- diffusers/utils/logging.py +1 -1
- diffusers/utils/model_card_template.md +24 -0
- diffusers/utils/outputs.py +14 -7
- diffusers/utils/peft_utils.py +1 -1
- diffusers/utils/state_dict_utils.py +1 -1
- diffusers/utils/testing_utils.py +2 -0
- diffusers/utils/torch_utils.py +1 -1
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/METADATA +5 -5
- diffusers-0.27.0.dist-info/RECORD +399 -0
- diffusers-0.26.2.dist-info/RECORD +0 -384
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/LICENSE +0 -0
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/WHEEL +0 -0
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 The HuggingFace Inc. team.
|
3
3
|
#
|
4
4
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
5
|
# you may not use this file except in compliance with the License.
|
@@ -28,6 +28,7 @@ from ..schedulers import (
|
|
28
28
|
DDIMScheduler,
|
29
29
|
DDPMScheduler,
|
30
30
|
DPMSolverMultistepScheduler,
|
31
|
+
EDMDPMSolverMultistepScheduler,
|
31
32
|
EulerAncestralDiscreteScheduler,
|
32
33
|
EulerDiscreteScheduler,
|
33
34
|
HeunDiscreteScheduler,
|
@@ -48,7 +49,6 @@ if is_transformers_available():
|
|
48
49
|
|
49
50
|
if is_accelerate_available():
|
50
51
|
from accelerate import init_empty_weights
|
51
|
-
from accelerate.utils import set_module_tensor_to_device
|
52
52
|
|
53
53
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
54
54
|
|
@@ -81,6 +81,87 @@ SCHEDULER_DEFAULT_CONFIG = {
|
|
81
81
|
"timestep_spacing": "leading",
|
82
82
|
}
|
83
83
|
|
84
|
+
|
85
|
+
# Reference config location for each Stable Cascade UNet variant, keyed by the
# variant name. Every variant lives in the same Hub repo and differs only by
# subfolder: stage C maps to the "prior" configs, stage B to the "decoder" configs.
STABLE_CASCADE_DEFAULT_CONFIGS = {
    variant: {"pretrained_model_name_or_path": "diffusers/stable-cascade-configs", "subfolder": subfolder}
    for variant, subfolder in (
        ("stage_c", "prior"),
        ("stage_c_lite", "prior_lite"),
        ("stage_b", "decoder"),
        ("stage_b_lite", "decoder_lite"),
    )
}
|
91
|
+
|
92
|
+
|
93
|
+
def convert_stable_cascade_unet_single_file_to_diffusers(original_state_dict):
    """Convert an original Stable Cascade UNet state dict to diffusers key names.

    All checkpoints get the same attention conversion:

    * every fused ``attn.in_proj_weight`` / ``attn.in_proj_bias`` entry is split
      into three equal slices along dim 0, stored as ``to_q`` / ``to_k`` / ``to_v``;
    * ``attn.out_proj.weight`` / ``attn.out_proj.bias`` are renamed to
      ``to_out.0.weight`` / ``to_out.0.bias``.

    A checkpoint containing ``clip_txt_mapper.weight`` is treated as stage C.
    All other checkpoints (stage B) additionally have ``clip_mapper.*`` renamed
    to ``clip_txt_pooled_mapper.*``. Every other entry is copied unchanged.

    Args:
        original_state_dict: mapping of original key -> tensor. Fused projection
            values must expose ``.shape`` and ``.chunk(3, 0)`` (e.g. ``torch.Tensor``).

    Returns:
        dict: a new state dict using diffusers key names. The input is not modified.

    Raises:
        ValueError: if a fused ``in_proj`` tensor's dim 0 is not divisible by 3.
    """
    is_stage_c = "clip_txt_mapper.weight" in original_state_dict

    # One-to-one renames, as (key suffix, substring to replace, replacement).
    # They are checked in order; the first matching suffix wins.
    renames = [
        ("out_proj.weight", "attn.out_proj.weight", "to_out.0.weight"),
        ("out_proj.bias", "attn.out_proj.bias", "to_out.0.bias"),
    ]
    if not is_stage_c:
        # Only the non-stage-C branch renames the pooled text projection.
        renames.extend(
            [
                ("clip_mapper.weight", "clip_mapper.weight", "clip_txt_pooled_mapper.weight"),
                ("clip_mapper.bias", "clip_mapper.bias", "clip_txt_pooled_mapper.bias"),
            ]
        )

    state_dict = {}
    for key, value in original_state_dict.items():
        if key.endswith(("in_proj_weight", "in_proj_bias")):
            param = "weight" if key.endswith("in_proj_weight") else "bias"
            # torch.chunk silently returns fewer (or unequal) pieces when the size
            # is not divisible by 3, which would yield wrong or missing q/k/v tensors.
            if value.shape[0] % 3 != 0:
                raise ValueError(
                    f"Cannot split fused attention projection `{key}` into q/k/v: "
                    f"dim 0 has size {value.shape[0]}, which is not divisible by 3."
                )
            q, k, v = value.chunk(3, 0)
            fused = f"attn.in_proj_{param}"
            state_dict[key.replace(fused, f"to_q.{param}")] = q
            state_dict[key.replace(fused, f"to_k.{param}")] = k
            state_dict[key.replace(fused, f"to_v.{param}")] = v
            continue

        for suffix, old, new in renames:
            if key.endswith(suffix):
                state_dict[key.replace(old, new)] = value
                break
        else:
            # No rename applies: copy the entry through unchanged.
            state_dict[key] = value

    return state_dict
|
147
|
+
|
148
|
+
|
149
|
+
def infer_stable_cascade_single_file_config(checkpoint):
    """Select the Stable Cascade reference config matching an original UNet checkpoint.

    The variant is detected from marker weights:

    * ``clip_txt_mapper.weight`` (stage C): dim 0 of 1536 selects ``"stage_c_lite"``,
      2048 selects ``"stage_c"``;
    * ``down_blocks.1.0.channelwise.0.weight`` (stage B): a last dim of 576 selects
      ``"stage_b_lite"``, 640 selects ``"stage_b"``.

    Stage C markers are checked first. The stage B marker is only consulted when
    no stage C variant matched.

    Args:
        checkpoint: original state dict (key -> tensor with a ``.shape``).

    Returns:
        The matching entry of ``STABLE_CASCADE_DEFAULT_CONFIGS``.

    Raises:
        ValueError: if the checkpoint matches none of the known variants.
    """
    stage_c_key = "clip_txt_mapper.weight"
    stage_b_key = "down_blocks.1.0.channelwise.0.weight"

    config_type = None
    if stage_c_key in checkpoint:
        config_type = {1536: "stage_c_lite", 2048: "stage_c"}.get(checkpoint[stage_c_key].shape[0])
    if config_type is None and stage_b_key in checkpoint:
        config_type = {576: "stage_b_lite", 640: "stage_b"}.get(checkpoint[stage_b_key].shape[-1])

    if config_type is None:
        # Without this check the lookup below would fail with an opaque UnboundLocalError/KeyError.
        raise ValueError(
            "Could not infer the Stable Cascade model type from the checkpoint. Expected "
            f"`{stage_c_key}` (stage C, dim 0 of 1536 or 2048) or `{stage_b_key}` "
            "(stage B, last dim of 576 or 640)."
        )

    return STABLE_CASCADE_DEFAULT_CONFIGS[config_type]
|
163
|
+
|
164
|
+
|
84
165
|
DIFFUSERS_TO_LDM_MAPPING = {
|
85
166
|
"unet": {
|
86
167
|
"layers": {
|
@@ -175,6 +256,8 @@ DIFFUSERS_TO_LDM_MAPPING = {
|
|
175
256
|
}
|
176
257
|
|
177
258
|
LDM_VAE_KEY = "first_stage_model."
|
259
|
+
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
|
260
|
+
PLAYGROUND_VAE_SCALING_FACTOR = 0.5
|
178
261
|
LDM_UNET_KEY = "model.diffusion_model."
|
179
262
|
LDM_CONTROLNET_KEY = "control_model."
|
180
263
|
LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
|
@@ -227,17 +310,34 @@ def fetch_ldm_config_and_checkpoint(
|
|
227
310
|
cache_dir=None,
|
228
311
|
local_files_only=None,
|
229
312
|
revision=None,
|
230
|
-
use_safetensors=True,
|
231
313
|
):
|
232
|
-
|
233
|
-
|
314
|
+
checkpoint = load_single_file_model_checkpoint(
|
315
|
+
pretrained_model_link_or_path,
|
316
|
+
resume_download=resume_download,
|
317
|
+
force_download=force_download,
|
318
|
+
proxies=proxies,
|
319
|
+
token=token,
|
320
|
+
cache_dir=cache_dir,
|
321
|
+
local_files_only=local_files_only,
|
322
|
+
revision=revision,
|
323
|
+
)
|
324
|
+
original_config = fetch_original_config(class_name, checkpoint, original_config_file)
|
325
|
+
|
326
|
+
return original_config, checkpoint
|
234
327
|
|
235
|
-
if from_safetensors and use_safetensors is False:
|
236
|
-
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
|
237
328
|
|
329
|
+
def load_single_file_model_checkpoint(
|
330
|
+
pretrained_model_link_or_path,
|
331
|
+
resume_download=False,
|
332
|
+
force_download=False,
|
333
|
+
proxies=None,
|
334
|
+
token=None,
|
335
|
+
cache_dir=None,
|
336
|
+
local_files_only=None,
|
337
|
+
revision=None,
|
338
|
+
):
|
238
339
|
if os.path.isfile(pretrained_model_link_or_path):
|
239
340
|
checkpoint = load_state_dict(pretrained_model_link_or_path)
|
240
|
-
|
241
341
|
else:
|
242
342
|
repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
|
243
343
|
checkpoint_path = _get_model_file(
|
@@ -257,9 +357,7 @@ def fetch_ldm_config_and_checkpoint(
|
|
257
357
|
while "state_dict" in checkpoint:
|
258
358
|
checkpoint = checkpoint["state_dict"]
|
259
359
|
|
260
|
-
|
261
|
-
|
262
|
-
return original_config, checkpoint
|
360
|
+
return checkpoint
|
263
361
|
|
264
362
|
|
265
363
|
def infer_original_config_file(class_name, checkpoint):
|
@@ -312,7 +410,7 @@ def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=
|
|
312
410
|
return original_config
|
313
411
|
|
314
412
|
|
315
|
-
def infer_model_type(original_config, model_type=None):
|
413
|
+
def infer_model_type(original_config, checkpoint, model_type=None):
|
316
414
|
if model_type is not None:
|
317
415
|
return model_type
|
318
416
|
|
@@ -330,7 +428,9 @@ def infer_model_type(original_config, model_type=None):
|
|
330
428
|
|
331
429
|
elif has_network_config:
|
332
430
|
context_dim = original_config["model"]["params"]["network_config"]["params"]["context_dim"]
|
333
|
-
if
|
431
|
+
if "edm_mean" in checkpoint and "edm_std" in checkpoint:
|
432
|
+
model_type = "Playground"
|
433
|
+
elif context_dim == 2048:
|
334
434
|
model_type = "SDXL"
|
335
435
|
else:
|
336
436
|
model_type = "SDXL-Refiner"
|
@@ -351,13 +451,13 @@ def set_image_size(pipeline_class_name, original_config, checkpoint, image_size=
|
|
351
451
|
return image_size
|
352
452
|
|
353
453
|
global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
|
354
|
-
model_type = infer_model_type(original_config, model_type)
|
454
|
+
model_type = infer_model_type(original_config, checkpoint, model_type)
|
355
455
|
|
356
456
|
if pipeline_class_name == "StableDiffusionUpscalePipeline":
|
357
457
|
image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"]
|
358
458
|
return image_size
|
359
459
|
|
360
|
-
elif model_type in ["SDXL", "SDXL-Refiner"]:
|
460
|
+
elif model_type in ["SDXL", "SDXL-Refiner", "Playground"]:
|
361
461
|
image_size = 1024
|
362
462
|
return image_size
|
363
463
|
|
@@ -465,8 +565,8 @@ def create_unet_diffusers_config(original_config, image_size: int):
|
|
465
565
|
config = {
|
466
566
|
"sample_size": image_size // vae_scale_factor,
|
467
567
|
"in_channels": unet_params["in_channels"],
|
468
|
-
"down_block_types":
|
469
|
-
"block_out_channels":
|
568
|
+
"down_block_types": down_block_types,
|
569
|
+
"block_out_channels": block_out_channels,
|
470
570
|
"layers_per_block": unet_params["num_res_blocks"],
|
471
571
|
"cross_attention_dim": context_dim,
|
472
572
|
"attention_head_dim": head_dim,
|
@@ -485,7 +585,7 @@ def create_unet_diffusers_config(original_config, image_size: int):
|
|
485
585
|
config["num_class_embeds"] = unet_params["num_classes"]
|
486
586
|
|
487
587
|
config["out_channels"] = unet_params["out_channels"]
|
488
|
-
config["up_block_types"] =
|
588
|
+
config["up_block_types"] = up_block_types
|
489
589
|
|
490
590
|
return config
|
491
591
|
|
@@ -513,12 +613,17 @@ def create_controlnet_diffusers_config(original_config, image_size: int):
|
|
513
613
|
return controlnet_config
|
514
614
|
|
515
615
|
|
516
|
-
def create_vae_diffusers_config(original_config, image_size, scaling_factor=None):
|
616
|
+
def create_vae_diffusers_config(original_config, image_size, scaling_factor=None, latents_mean=None, latents_std=None):
|
517
617
|
"""
|
518
618
|
Creates a config for the diffusers based on the config of the LDM model.
|
519
619
|
"""
|
520
620
|
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
|
521
|
-
scaling_factor
|
621
|
+
if (scaling_factor is None) and (latents_mean is not None) and (latents_std is not None):
|
622
|
+
scaling_factor = PLAYGROUND_VAE_SCALING_FACTOR
|
623
|
+
elif (scaling_factor is None) and ("scale_factor" in original_config["model"]["params"]):
|
624
|
+
scaling_factor = original_config["model"]["params"]["scale_factor"]
|
625
|
+
elif scaling_factor is None:
|
626
|
+
scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR
|
522
627
|
|
523
628
|
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
|
524
629
|
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
@@ -528,13 +633,15 @@ def create_vae_diffusers_config(original_config, image_size, scaling_factor=None
|
|
528
633
|
"sample_size": image_size,
|
529
634
|
"in_channels": vae_params["in_channels"],
|
530
635
|
"out_channels": vae_params["out_ch"],
|
531
|
-
"down_block_types":
|
532
|
-
"up_block_types":
|
533
|
-
"block_out_channels":
|
636
|
+
"down_block_types": down_block_types,
|
637
|
+
"up_block_types": up_block_types,
|
638
|
+
"block_out_channels": block_out_channels,
|
534
639
|
"latent_channels": vae_params["z_channels"],
|
535
640
|
"layers_per_block": vae_params["num_res_blocks"],
|
536
641
|
"scaling_factor": scaling_factor,
|
537
642
|
}
|
643
|
+
if latents_mean is not None and latents_std is not None:
|
644
|
+
config.update({"latents_mean": latents_mean, "latents_std": latents_std})
|
538
645
|
|
539
646
|
return config
|
540
647
|
|
@@ -853,7 +960,7 @@ def convert_controlnet_checkpoint(
|
|
853
960
|
|
854
961
|
|
855
962
|
def create_diffusers_controlnet_model_from_ldm(
|
856
|
-
pipeline_class_name, original_config, checkpoint, upcast_attention=False, image_size=None
|
963
|
+
pipeline_class_name, original_config, checkpoint, upcast_attention=False, image_size=None, torch_dtype=None
|
857
964
|
):
|
858
965
|
# import here to avoid circular imports
|
859
966
|
from ..models import ControlNetModel
|
@@ -870,11 +977,25 @@ def create_diffusers_controlnet_model_from_ldm(
|
|
870
977
|
controlnet = ControlNetModel(**diffusers_config)
|
871
978
|
|
872
979
|
if is_accelerate_available():
|
873
|
-
|
874
|
-
|
980
|
+
from ..models.modeling_utils import load_model_dict_into_meta
|
981
|
+
|
982
|
+
unexpected_keys = load_model_dict_into_meta(
|
983
|
+
controlnet, diffusers_format_controlnet_checkpoint, dtype=torch_dtype
|
984
|
+
)
|
985
|
+
if controlnet._keys_to_ignore_on_load_unexpected is not None:
|
986
|
+
for pat in controlnet._keys_to_ignore_on_load_unexpected:
|
987
|
+
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
988
|
+
|
989
|
+
if len(unexpected_keys) > 0:
|
990
|
+
logger.warning(
|
991
|
+
f"Some weights of the model checkpoint were not used when initializing {controlnet.__name__}: \n {[', '.join(unexpected_keys)]}"
|
992
|
+
)
|
875
993
|
else:
|
876
994
|
controlnet.load_state_dict(diffusers_format_controlnet_checkpoint)
|
877
995
|
|
996
|
+
if torch_dtype is not None:
|
997
|
+
controlnet = controlnet.to(torch_dtype)
|
998
|
+
|
878
999
|
return {"controlnet": controlnet}
|
879
1000
|
|
880
1001
|
|
@@ -1010,7 +1131,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
|
1010
1131
|
return new_checkpoint
|
1011
1132
|
|
1012
1133
|
|
1013
|
-
def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_files_only=False):
|
1134
|
+
def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_files_only=False, torch_dtype=None):
|
1014
1135
|
try:
|
1015
1136
|
config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)
|
1016
1137
|
except Exception:
|
@@ -1034,14 +1155,26 @@ def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_
|
|
1034
1155
|
text_model_dict[diffusers_key] = checkpoint[key]
|
1035
1156
|
|
1036
1157
|
if is_accelerate_available():
|
1037
|
-
|
1038
|
-
|
1158
|
+
from ..models.modeling_utils import load_model_dict_into_meta
|
1159
|
+
|
1160
|
+
unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict, dtype=torch_dtype)
|
1161
|
+
if text_model._keys_to_ignore_on_load_unexpected is not None:
|
1162
|
+
for pat in text_model._keys_to_ignore_on_load_unexpected:
|
1163
|
+
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
1164
|
+
|
1165
|
+
if len(unexpected_keys) > 0:
|
1166
|
+
logger.warning(
|
1167
|
+
f"Some weights of the model checkpoint were not used when initializing {text_model.__class__.__name__}: \n {[', '.join(unexpected_keys)]}"
|
1168
|
+
)
|
1039
1169
|
else:
|
1040
1170
|
if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)):
|
1041
1171
|
text_model_dict.pop("text_model.embeddings.position_ids", None)
|
1042
1172
|
|
1043
1173
|
text_model.load_state_dict(text_model_dict)
|
1044
1174
|
|
1175
|
+
if torch_dtype is not None:
|
1176
|
+
text_model = text_model.to(torch_dtype)
|
1177
|
+
|
1045
1178
|
return text_model
|
1046
1179
|
|
1047
1180
|
|
@@ -1051,6 +1184,7 @@ def create_text_encoder_from_open_clip_checkpoint(
|
|
1051
1184
|
prefix="cond_stage_model.model.",
|
1052
1185
|
has_projection=False,
|
1053
1186
|
local_files_only=False,
|
1187
|
+
torch_dtype=None,
|
1054
1188
|
**config_kwargs,
|
1055
1189
|
):
|
1056
1190
|
try:
|
@@ -1112,13 +1246,21 @@ def create_text_encoder_from_open_clip_checkpoint(
|
|
1112
1246
|
text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim]
|
1113
1247
|
text_model_dict[diffusers_key + ".k_proj.bias"] = weight_value[text_proj_dim : text_proj_dim * 2]
|
1114
1248
|
text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :]
|
1115
|
-
|
1116
1249
|
else:
|
1117
1250
|
text_model_dict[diffusers_key] = checkpoint[key]
|
1118
1251
|
|
1119
1252
|
if is_accelerate_available():
|
1120
|
-
|
1121
|
-
|
1253
|
+
from ..models.modeling_utils import load_model_dict_into_meta
|
1254
|
+
|
1255
|
+
unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict, dtype=torch_dtype)
|
1256
|
+
if text_model._keys_to_ignore_on_load_unexpected is not None:
|
1257
|
+
for pat in text_model._keys_to_ignore_on_load_unexpected:
|
1258
|
+
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
1259
|
+
|
1260
|
+
if len(unexpected_keys) > 0:
|
1261
|
+
logger.warning(
|
1262
|
+
f"Some weights of the model checkpoint were not used when initializing {text_model.__class__.__name__}: \n {[', '.join(unexpected_keys)]}"
|
1263
|
+
)
|
1122
1264
|
|
1123
1265
|
else:
|
1124
1266
|
if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)):
|
@@ -1126,6 +1268,9 @@ def create_text_encoder_from_open_clip_checkpoint(
|
|
1126
1268
|
|
1127
1269
|
text_model.load_state_dict(text_model_dict)
|
1128
1270
|
|
1271
|
+
if torch_dtype is not None:
|
1272
|
+
text_model = text_model.to(torch_dtype)
|
1273
|
+
|
1129
1274
|
return text_model
|
1130
1275
|
|
1131
1276
|
|
@@ -1134,15 +1279,18 @@ def create_diffusers_unet_model_from_ldm(
|
|
1134
1279
|
original_config,
|
1135
1280
|
checkpoint,
|
1136
1281
|
num_in_channels=None,
|
1137
|
-
upcast_attention=
|
1282
|
+
upcast_attention=None,
|
1138
1283
|
extract_ema=False,
|
1139
1284
|
image_size=None,
|
1285
|
+
torch_dtype=None,
|
1286
|
+
model_type=None,
|
1140
1287
|
):
|
1141
1288
|
from ..models import UNet2DConditionModel
|
1142
1289
|
|
1143
1290
|
if num_in_channels is None:
|
1144
1291
|
if pipeline_class_name in [
|
1145
1292
|
"StableDiffusionInpaintPipeline",
|
1293
|
+
"StableDiffusionControlNetInpaintPipeline",
|
1146
1294
|
"StableDiffusionXLInpaintPipeline",
|
1147
1295
|
"StableDiffusionXLControlNetInpaintPipeline",
|
1148
1296
|
]:
|
@@ -1154,34 +1302,76 @@ def create_diffusers_unet_model_from_ldm(
|
|
1154
1302
|
else:
|
1155
1303
|
num_in_channels = 4
|
1156
1304
|
|
1157
|
-
image_size = set_image_size(
|
1305
|
+
image_size = set_image_size(
|
1306
|
+
pipeline_class_name, original_config, checkpoint, image_size=image_size, model_type=model_type
|
1307
|
+
)
|
1158
1308
|
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
|
1159
1309
|
unet_config["in_channels"] = num_in_channels
|
1160
|
-
|
1310
|
+
if upcast_attention is not None:
|
1311
|
+
unet_config["upcast_attention"] = upcast_attention
|
1161
1312
|
|
1162
1313
|
diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config, extract_ema=extract_ema)
|
1163
1314
|
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
1315
|
+
|
1164
1316
|
with ctx():
|
1165
1317
|
unet = UNet2DConditionModel(**unet_config)
|
1166
1318
|
|
1167
1319
|
if is_accelerate_available():
|
1168
|
-
|
1169
|
-
|
1320
|
+
from ..models.modeling_utils import load_model_dict_into_meta
|
1321
|
+
|
1322
|
+
unexpected_keys = load_model_dict_into_meta(unet, diffusers_format_unet_checkpoint, dtype=torch_dtype)
|
1323
|
+
if unet._keys_to_ignore_on_load_unexpected is not None:
|
1324
|
+
for pat in unet._keys_to_ignore_on_load_unexpected:
|
1325
|
+
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
1326
|
+
|
1327
|
+
if len(unexpected_keys) > 0:
|
1328
|
+
logger.warning(
|
1329
|
+
f"Some weights of the model checkpoint were not used when initializing {unet.__name__}: \n {[', '.join(unexpected_keys)]}"
|
1330
|
+
)
|
1170
1331
|
else:
|
1171
1332
|
unet.load_state_dict(diffusers_format_unet_checkpoint)
|
1172
1333
|
|
1334
|
+
if torch_dtype is not None:
|
1335
|
+
unet = unet.to(torch_dtype)
|
1336
|
+
|
1173
1337
|
return {"unet": unet}
|
1174
1338
|
|
1175
1339
|
|
1176
1340
|
def create_diffusers_vae_model_from_ldm(
|
1177
|
-
pipeline_class_name,
|
1341
|
+
pipeline_class_name,
|
1342
|
+
original_config,
|
1343
|
+
checkpoint,
|
1344
|
+
image_size=None,
|
1345
|
+
scaling_factor=None,
|
1346
|
+
torch_dtype=None,
|
1347
|
+
model_type=None,
|
1178
1348
|
):
|
1179
1349
|
# import here to avoid circular imports
|
1180
1350
|
from ..models import AutoencoderKL
|
1181
1351
|
|
1182
|
-
image_size = set_image_size(
|
1352
|
+
image_size = set_image_size(
|
1353
|
+
pipeline_class_name, original_config, checkpoint, image_size=image_size, model_type=model_type
|
1354
|
+
)
|
1355
|
+
model_type = infer_model_type(original_config, checkpoint, model_type)
|
1183
1356
|
|
1184
|
-
|
1357
|
+
if model_type == "Playground":
|
1358
|
+
edm_mean = (
|
1359
|
+
checkpoint["edm_mean"].to(dtype=torch_dtype).tolist() if torch_dtype else checkpoint["edm_mean"].tolist()
|
1360
|
+
)
|
1361
|
+
edm_std = (
|
1362
|
+
checkpoint["edm_std"].to(dtype=torch_dtype).tolist() if torch_dtype else checkpoint["edm_std"].tolist()
|
1363
|
+
)
|
1364
|
+
else:
|
1365
|
+
edm_mean = None
|
1366
|
+
edm_std = None
|
1367
|
+
|
1368
|
+
vae_config = create_vae_diffusers_config(
|
1369
|
+
original_config,
|
1370
|
+
image_size=image_size,
|
1371
|
+
scaling_factor=scaling_factor,
|
1372
|
+
latents_mean=edm_mean,
|
1373
|
+
latents_std=edm_std,
|
1374
|
+
)
|
1185
1375
|
diffusers_format_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
1186
1376
|
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
1187
1377
|
|
@@ -1189,11 +1379,23 @@ def create_diffusers_vae_model_from_ldm(
|
|
1189
1379
|
vae = AutoencoderKL(**vae_config)
|
1190
1380
|
|
1191
1381
|
if is_accelerate_available():
|
1192
|
-
|
1193
|
-
|
1382
|
+
from ..models.modeling_utils import load_model_dict_into_meta
|
1383
|
+
|
1384
|
+
unexpected_keys = load_model_dict_into_meta(vae, diffusers_format_vae_checkpoint, dtype=torch_dtype)
|
1385
|
+
if vae._keys_to_ignore_on_load_unexpected is not None:
|
1386
|
+
for pat in vae._keys_to_ignore_on_load_unexpected:
|
1387
|
+
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
1388
|
+
|
1389
|
+
if len(unexpected_keys) > 0:
|
1390
|
+
logger.warning(
|
1391
|
+
f"Some weights of the model checkpoint were not used when initializing {vae.__name__}: \n {[', '.join(unexpected_keys)]}"
|
1392
|
+
)
|
1194
1393
|
else:
|
1195
1394
|
vae.load_state_dict(diffusers_format_vae_checkpoint)
|
1196
1395
|
|
1396
|
+
if torch_dtype is not None:
|
1397
|
+
vae = vae.to(torch_dtype)
|
1398
|
+
|
1197
1399
|
return {"vae": vae}
|
1198
1400
|
|
1199
1401
|
|
@@ -1202,8 +1404,9 @@ def create_text_encoders_and_tokenizers_from_ldm(
|
|
1202
1404
|
checkpoint,
|
1203
1405
|
model_type=None,
|
1204
1406
|
local_files_only=False,
|
1407
|
+
torch_dtype=None,
|
1205
1408
|
):
|
1206
|
-
model_type = infer_model_type(original_config, model_type=model_type)
|
1409
|
+
model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)
|
1207
1410
|
|
1208
1411
|
if model_type == "FrozenOpenCLIPEmbedder":
|
1209
1412
|
config_name = "stabilityai/stable-diffusion-2"
|
@@ -1211,7 +1414,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
|
|
1211
1414
|
|
1212
1415
|
try:
|
1213
1416
|
text_encoder = create_text_encoder_from_open_clip_checkpoint(
|
1214
|
-
config_name, checkpoint, local_files_only=local_files_only, **config_kwargs
|
1417
|
+
config_name, checkpoint, local_files_only=local_files_only, torch_dtype=torch_dtype, **config_kwargs
|
1215
1418
|
)
|
1216
1419
|
tokenizer = CLIPTokenizer.from_pretrained(
|
1217
1420
|
config_name, subfolder="tokenizer", local_files_only=local_files_only
|
@@ -1227,7 +1430,10 @@ def create_text_encoders_and_tokenizers_from_ldm(
|
|
1227
1430
|
try:
|
1228
1431
|
config_name = "openai/clip-vit-large-patch14"
|
1229
1432
|
text_encoder = create_text_encoder_from_ldm_clip_checkpoint(
|
1230
|
-
config_name,
|
1433
|
+
config_name,
|
1434
|
+
checkpoint,
|
1435
|
+
local_files_only=local_files_only,
|
1436
|
+
torch_dtype=torch_dtype,
|
1231
1437
|
)
|
1232
1438
|
tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
|
1233
1439
|
|
@@ -1251,6 +1457,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
|
|
1251
1457
|
prefix=prefix,
|
1252
1458
|
has_projection=True,
|
1253
1459
|
local_files_only=local_files_only,
|
1460
|
+
torch_dtype=torch_dtype,
|
1254
1461
|
**config_kwargs,
|
1255
1462
|
)
|
1256
1463
|
except Exception:
|
@@ -1266,12 +1473,12 @@ def create_text_encoders_and_tokenizers_from_ldm(
|
|
1266
1473
|
"text_encoder_2": text_encoder_2,
|
1267
1474
|
}
|
1268
1475
|
|
1269
|
-
elif model_type
|
1476
|
+
elif model_type in ["SDXL", "Playground"]:
|
1270
1477
|
try:
|
1271
1478
|
config_name = "openai/clip-vit-large-patch14"
|
1272
1479
|
tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
|
1273
1480
|
text_encoder = create_text_encoder_from_ldm_clip_checkpoint(
|
1274
|
-
config_name, checkpoint, local_files_only=local_files_only
|
1481
|
+
config_name, checkpoint, local_files_only=local_files_only, torch_dtype=torch_dtype
|
1275
1482
|
)
|
1276
1483
|
|
1277
1484
|
except Exception:
|
@@ -1290,6 +1497,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
|
|
1290
1497
|
prefix=prefix,
|
1291
1498
|
has_projection=True,
|
1292
1499
|
local_files_only=local_files_only,
|
1500
|
+
torch_dtype=torch_dtype,
|
1293
1501
|
**config_kwargs,
|
1294
1502
|
)
|
1295
1503
|
except Exception:
|
@@ -1316,7 +1524,7 @@ def create_scheduler_from_ldm(
|
|
1316
1524
|
model_type=None,
|
1317
1525
|
):
|
1318
1526
|
scheduler_config = get_default_scheduler_config()
|
1319
|
-
model_type = infer_model_type(original_config, model_type=model_type)
|
1527
|
+
model_type = infer_model_type(original_config, checkpoint=checkpoint, model_type=model_type)
|
1320
1528
|
|
1321
1529
|
global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
|
1322
1530
|
|
@@ -1339,7 +1547,8 @@ def create_scheduler_from_ldm(
|
|
1339
1547
|
|
1340
1548
|
if model_type in ["SDXL", "SDXL-Refiner"]:
|
1341
1549
|
scheduler_type = "euler"
|
1342
|
-
|
1550
|
+
elif model_type == "Playground":
|
1551
|
+
scheduler_type = "edm_dpm_solver_multistep"
|
1343
1552
|
else:
|
1344
1553
|
beta_start = original_config["model"]["params"].get("linear_start", 0.02)
|
1345
1554
|
beta_end = original_config["model"]["params"].get("linear_end", 0.085)
|
@@ -1371,6 +1580,26 @@ def create_scheduler_from_ldm(
|
|
1371
1580
|
elif scheduler_type == "ddim":
|
1372
1581
|
scheduler = DDIMScheduler.from_config(scheduler_config)
|
1373
1582
|
|
1583
|
+
elif scheduler_type == "edm_dpm_solver_multistep":
|
1584
|
+
scheduler_config = {
|
1585
|
+
"algorithm_type": "dpmsolver++",
|
1586
|
+
"dynamic_thresholding_ratio": 0.995,
|
1587
|
+
"euler_at_final": False,
|
1588
|
+
"final_sigmas_type": "zero",
|
1589
|
+
"lower_order_final": True,
|
1590
|
+
"num_train_timesteps": 1000,
|
1591
|
+
"prediction_type": "epsilon",
|
1592
|
+
"rho": 7.0,
|
1593
|
+
"sample_max_value": 1.0,
|
1594
|
+
"sigma_data": 0.5,
|
1595
|
+
"sigma_max": 80.0,
|
1596
|
+
"sigma_min": 0.002,
|
1597
|
+
"solver_order": 2,
|
1598
|
+
"solver_type": "midpoint",
|
1599
|
+
"thresholding": False,
|
1600
|
+
}
|
1601
|
+
scheduler = EDMDPMSolverMultistepScheduler(**scheduler_config)
|
1602
|
+
|
1374
1603
|
else:
|
1375
1604
|
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
|
1376
1605
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -215,7 +215,7 @@ class TextualInversionLoaderMixin:
|
|
215
215
|
embedding = state_dict["string_to_param"]["*"]
|
216
216
|
else:
|
217
217
|
raise ValueError(
|
218
|
-
f"Loaded state
|
218
|
+
f"Loaded state dictionary is incorrect: {state_dict}. \n\n"
|
219
219
|
"Please verify that the loaded state dictionary of the textual embedding either only has a single key or includes the `string_to_param`"
|
220
220
|
" input key."
|
221
221
|
)
|
@@ -457,6 +457,8 @@ class TextualInversionLoaderMixin:
|
|
457
457
|
def unload_textual_inversion(
|
458
458
|
self,
|
459
459
|
tokens: Optional[Union[str, List[str]]] = None,
|
460
|
+
tokenizer: Optional["PreTrainedTokenizer"] = None,
|
461
|
+
text_encoder: Optional["PreTrainedModel"] = None,
|
460
462
|
):
|
461
463
|
r"""
|
462
464
|
Unload Textual Inversion embeddings from the text encoder of [`StableDiffusionPipeline`]
|
@@ -481,11 +483,28 @@ class TextualInversionLoaderMixin:
|
|
481
483
|
|
482
484
|
# Remove just one token
|
483
485
|
pipeline.unload_textual_inversion("<moe-bius>")
|
486
|
+
|
487
|
+
# Example 3: unload from SDXL
|
488
|
+
pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
|
489
|
+
embedding_path = hf_hub_download(repo_id="linoyts/web_y2k", filename="web_y2k_emb.safetensors", repo_type="model")
|
490
|
+
|
491
|
+
# load embeddings to the text encoders
|
492
|
+
state_dict = load_file(embedding_path)
|
493
|
+
|
494
|
+
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
|
495
|
+
pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
|
496
|
+
# load embeddings of text_encoder 2 (CLIP ViT-G/14)
|
497
|
+
pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
|
498
|
+
|
499
|
+
# Unload explicitly from both text encoders abd tokenizers
|
500
|
+
pipeline.unload_textual_inversion(tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
|
501
|
+
pipeline.unload_textual_inversion(tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
|
502
|
+
|
484
503
|
```
|
485
504
|
"""
|
486
505
|
|
487
|
-
tokenizer = getattr(self, "tokenizer", None)
|
488
|
-
text_encoder = getattr(self, "text_encoder", None)
|
506
|
+
tokenizer = tokenizer or getattr(self, "tokenizer", None)
|
507
|
+
text_encoder = text_encoder or getattr(self, "text_encoder", None)
|
489
508
|
|
490
509
|
# Get textual inversion tokens and ids
|
491
510
|
token_ids = []
|