diffusers 0.26.2__py3-none-any.whl → 0.27.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- diffusers/__init__.py +20 -1
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/diffusers_cli.py +1 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +7 -3
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +2 -2
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/image_processor.py +110 -4
- diffusers/loaders/autoencoder.py +28 -8
- diffusers/loaders/controlnet.py +17 -8
- diffusers/loaders/ip_adapter.py +86 -23
- diffusers/loaders/lora.py +105 -310
- diffusers/loaders/lora_conversion_utils.py +1 -1
- diffusers/loaders/peft.py +1 -1
- diffusers/loaders/single_file.py +51 -12
- diffusers/loaders/single_file_utils.py +278 -49
- diffusers/loaders/textual_inversion.py +23 -4
- diffusers/loaders/unet.py +195 -41
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +3 -1
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +26 -36
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +171 -114
- diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +1 -1
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flax.py +1 -1
- diffusers/models/downsampling.py +8 -12
- diffusers/models/dual_transformer_2d.py +1 -1
- diffusers/models/embeddings.py +3 -4
- diffusers/models/embeddings_flax.py +1 -1
- diffusers/models/lora.py +33 -10
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +4 -6
- diffusers/models/normalization.py +1 -1
- diffusers/models/resnet.py +31 -58
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/t5_film_transformer.py +1 -1
- diffusers/models/transformer_2d.py +1 -1
- diffusers/models/transformer_temporal.py +1 -1
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/t5_film_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +29 -31
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unet_1d.py +1 -1
- diffusers/models/unet_1d_blocks.py +1 -1
- diffusers/models/unet_2d.py +1 -1
- diffusers/models/unet_2d_blocks.py +1 -1
- diffusers/models/unet_2d_condition.py +1 -1
- diffusers/models/unets/__init__.py +1 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +4 -4
- diffusers/models/unets/unet_2d_blocks.py +238 -98
- diffusers/models/unets/unet_2d_blocks_flax.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +420 -323
- diffusers/models/unets/unet_2d_condition_flax.py +21 -12
- diffusers/models/unets/unet_3d_blocks.py +50 -40
- diffusers/models/unets/unet_3d_condition.py +47 -8
- diffusers/models/unets/unet_i2vgen_xl.py +75 -30
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +48 -8
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +610 -0
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +10 -16
- diffusers/models/vae_flax.py +1 -1
- diffusers/models/vq_model.py +1 -1
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +26 -0
- diffusers/pipelines/amused/pipeline_amused.py +1 -1
- diffusers/pipelines/amused/pipeline_amused_img2img.py +1 -1
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff.py +162 -417
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +165 -137
- diffusers/pipelines/animatediff/pipeline_output.py +7 -6
- diffusers/pipelines/audioldm/pipeline_audioldm.py +3 -19
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +3 -3
- diffusers/pipelines/auto_pipeline.py +7 -16
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +90 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +98 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +92 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +145 -70
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +126 -89
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +108 -96
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -1
- diffusers/pipelines/ddim/pipeline_ddim.py +1 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +5 -5
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +5 -5
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +5 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +10 -120
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -91
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +1 -1
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +1 -1
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +5 -4
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +5 -4
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +7 -22
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -39
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +5 -5
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -22
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -2
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +1 -1
- diffusers/pipelines/free_init_utils.py +184 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +22 -104
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +2 -2
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +104 -93
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +112 -74
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ledits_pp/__init__.py +55 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +1505 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +1797 -0
- diffusers/pipelines/ledits_pp/pipeline_output.py +43 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +3 -19
- diffusers/pipelines/onnx_utils.py +1 -1
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +3 -3
- diffusers/pipelines/pia/pipeline_pia.py +168 -327
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +508 -0
- diffusers/pipelines/pipeline_utils.py +188 -534
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +56 -10
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +3 -3
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/__init__.py +50 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +482 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +311 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +638 -0
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +4 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +90 -146
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +4 -32
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -119
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -119
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +13 -59
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -31
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -33
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -21
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -21
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +5 -21
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +9 -38
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -34
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +6 -35
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +7 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +4 -124
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +282 -80
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +94 -46
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +3 -3
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +6 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +96 -148
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -154
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +98 -153
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +25 -87
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +89 -80
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +5 -49
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +80 -88
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +15 -86
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +20 -93
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +5 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +3 -19
- diffusers/pipelines/unclip/pipeline_unclip.py +1 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -1
- diffusers/pipelines/unclip/text_proj.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +35 -35
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +4 -21
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -5
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +2 -2
- diffusers/schedulers/__init__.py +7 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +1 -1
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +42 -19
- diffusers/schedulers/scheduling_ddim.py +2 -4
- diffusers/schedulers/scheduling_ddim_flax.py +13 -5
- diffusers/schedulers/scheduling_ddim_inverse.py +2 -4
- diffusers/schedulers/scheduling_ddim_parallel.py +2 -4
- diffusers/schedulers/scheduling_ddpm.py +2 -4
- diffusers/schedulers/scheduling_ddpm_flax.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -4
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +46 -19
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -21
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +9 -7
- diffusers/schedulers/scheduling_dpmsolver_sde.py +35 -35
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +52 -21
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +683 -0
- diffusers/schedulers/scheduling_edm_euler.py +381 -0
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +43 -15
- diffusers/schedulers/scheduling_euler_discrete.py +42 -17
- diffusers/schedulers/scheduling_euler_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_heun_discrete.py +35 -35
- diffusers/schedulers/scheduling_ipndm.py +37 -11
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +44 -44
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +44 -44
- diffusers/schedulers/scheduling_karras_ve_flax.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +38 -14
- diffusers/schedulers/scheduling_lms_discrete.py +43 -15
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +2 -4
- diffusers/schedulers/scheduling_pndm_flax.py +2 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +41 -9
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +1 -1
- diffusers/schedulers/scheduling_tcd.py +686 -0
- diffusers/schedulers/scheduling_unclip.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +46 -19
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +9 -2
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +1 -1
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +75 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +3 -3
- diffusers/utils/hub_utils.py +60 -16
- diffusers/utils/import_utils.py +15 -1
- diffusers/utils/loading_utils.py +2 -0
- diffusers/utils/logging.py +1 -1
- diffusers/utils/model_card_template.md +24 -0
- diffusers/utils/outputs.py +14 -7
- diffusers/utils/peft_utils.py +1 -1
- diffusers/utils/state_dict_utils.py +1 -1
- diffusers/utils/testing_utils.py +2 -0
- diffusers/utils/torch_utils.py +1 -1
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/METADATA +5 -5
- diffusers-0.27.0.dist-info/RECORD +399 -0
- diffusers-0.26.2.dist-info/RECORD +0 -384
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/LICENSE +0 -0
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/WHEEL +0 -0
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,508 @@
|
|
1
|
+
# coding=utf-8
|
2
|
+
# Copyright 2024 The HuggingFace Inc. team.
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
|
17
|
+
import importlib
|
18
|
+
import os
|
19
|
+
import re
|
20
|
+
import warnings
|
21
|
+
from pathlib import Path
|
22
|
+
from typing import Any, Dict, List, Optional, Union
|
23
|
+
|
24
|
+
import torch
|
25
|
+
from huggingface_hub import (
|
26
|
+
model_info,
|
27
|
+
)
|
28
|
+
from packaging import version
|
29
|
+
|
30
|
+
from ..utils import (
|
31
|
+
SAFETENSORS_WEIGHTS_NAME,
|
32
|
+
WEIGHTS_NAME,
|
33
|
+
get_class_from_dynamic_module,
|
34
|
+
is_peft_available,
|
35
|
+
is_transformers_available,
|
36
|
+
logging,
|
37
|
+
)
|
38
|
+
from ..utils.torch_utils import is_compiled_module
|
39
|
+
|
40
|
+
|
41
|
+
if is_transformers_available():
|
42
|
+
import transformers
|
43
|
+
from transformers import PreTrainedModel
|
44
|
+
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
|
45
|
+
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
|
46
|
+
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
|
47
|
+
from huggingface_hub.utils import validate_hf_hub_args
|
48
|
+
|
49
|
+
from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
|
50
|
+
|
51
|
+
|
52
|
+
INDEX_FILE = "diffusion_pytorch_model.bin"
|
53
|
+
CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
|
54
|
+
DUMMY_MODULES_FOLDER = "diffusers.utils"
|
55
|
+
TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"
|
56
|
+
CONNECTED_PIPES_KEYS = ["prior"]
|
57
|
+
|
58
|
+
logger = logging.get_logger(__name__)
|
59
|
+
|
60
|
+
# Per library: class name -> pair of method names (first the save method, then the load method).
LOADABLE_CLASSES = {
    "diffusers": {
        "ModelMixin": ["save_pretrained", "from_pretrained"],
        "SchedulerMixin": ["save_pretrained", "from_pretrained"],
        "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
        "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
    },
    "transformers": {
        "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
        "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
        "PreTrainedModel": ["save_pretrained", "from_pretrained"],
        "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
        "ProcessorMixin": ["save_pretrained", "from_pretrained"],
        "ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
    },
    "onnxruntime.training": {
        "ORTModule": ["save_pretrained", "from_pretrained"],
    },
}

# Flattened view of LOADABLE_CLASSES: every class name from every library in a single
# mapping. The method-name lists are shared with LOADABLE_CLASSES, not copied.
ALL_IMPORTABLE_CLASSES = {}
for library in LOADABLE_CLASSES:
    ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
|
83
|
+
|
84
|
+
|
85
|
+
def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
    """
    Check whether every PyTorch ``.bin`` file in `filenames` has a matching ``.safetensors`` file.

    The expected safetensors name for a ``.bin`` file is derived as follows:
        - the ".bin" extension is replaced with ".safetensors"
        - if the file stem starts with "pytorch_model", every occurrence of "pytorch_model" in the stem is replaced
          with "model" (e.g. "text_encoder/pytorch_model.bin" -> "text_encoder/model.safetensors")

    Args:
        filenames: Iterable of "/"-separated file paths.
        variant: Accepted for interface compatibility; not read by this function.
        passed_components: Folder names to ignore. Only paths of the exact form "<folder>/<file>" are skipped.

    Returns:
        `True` if no ``.bin`` file lacks its safetensors counterpart (including when there are no ``.bin`` files at
        all). `False` as soon as one counterpart is missing; a warning naming the missing file is logged first.
    """
    passed_components = passed_components or []

    pt_filenames = []
    sf_filenames = set()

    for filename in filenames:
        # Files sitting directly inside the folder of a passed component are not considered.
        path_parts = filename.split("/")
        if len(path_parts) == 2 and path_parts[0] in passed_components:
            continue

        extension = os.path.splitext(filename)[1]
        if extension == ".bin":
            pt_filenames.append(os.path.normpath(filename))
        elif extension == ".safetensors":
            sf_filenames.add(os.path.normpath(filename))

    for pt_filename in pt_filenames:
        # 'foo/bar/baz.bin' -> folder = 'foo/bar', stem = 'baz'
        folder, basename = os.path.split(pt_filename)
        stem = os.path.splitext(basename)[0]

        if stem.startswith("pytorch_model"):
            stem = stem.replace("pytorch_model", "model")

        expected_sf_filename = os.path.normpath(os.path.join(folder, stem))
        expected_sf_filename = f"{expected_sf_filename}.safetensors"
        if expected_sf_filename not in sf_filenames:
            logger.warning(f"{expected_sf_filename} not found")
            return False

    return True
|
131
|
+
|
132
|
+
|
133
|
+
def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
    """
    Select the weight and weight-index files from `filenames` that should be used for `variant`.

    Only the last "/"-separated segment of each path is matched against the naming patterns. Every file that follows
    the variant naming scheme is selected. A non-variant file is additionally selected only when the variant-named
    counterpart derived from it is not already among the selected files.

    Args:
        filenames: Collection of "/"-separated file paths. It is iterated several times.
        variant: Variant label (e.g. "fp16") or `None`. It is matched literally.

    Returns:
        A 2-tuple `(usable_filenames, variant_filenames)`, both sets of entries taken from `filenames`.
        NOTE(review): the declared return annotation does not describe this tuple. It is left unchanged here because
        correcting it needs an additional `typing` import at file level.
    """
    weight_names = [
        WEIGHTS_NAME,
        SAFETENSORS_WEIGHTS_NAME,
        FLAX_WEIGHTS_NAME,
        ONNX_WEIGHTS_NAME,
        ONNX_EXTERNAL_WEIGHTS_NAME,
    ]

    if is_transformers_available():
        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]

    # alternation of everything before the first dot of each known weight name
    prefix_alternatives = "|".join(w.split(".")[0] for w in weight_names)
    # alternation of everything after the last dot of each known weight name (bin, safetensors, ...)
    suffix_alternatives = "|".join(w.split(".")[-1] for w in weight_names)
    # shard marker such as 00001-of-00002
    transformers_index_format = r"\d{5}-of-\d{5}"

    def matching(pattern):
        """Return the entries of `filenames` whose last path segment matches `pattern`."""
        return {f for f in filenames if pattern.match(f.split("/")[-1]) is not None}

    if variant is not None:
        # The variant is a literal label: escape it so that regex metacharacters in it are not interpreted.
        variant_pattern = re.escape(str(variant))
        # `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors`
        variant_file_re = re.compile(
            rf"({prefix_alternatives})\.({variant_pattern}|{variant_pattern}-{transformers_index_format})\.({suffix_alternatives})$"
        )
        # `text_encoder/pytorch_model.bin.index.fp16.json`
        variant_index_re = re.compile(
            rf"({prefix_alternatives})\.({suffix_alternatives})\.index\.{variant_pattern}\.json$"
        )
        variant_filenames = matching(variant_file_re) | matching(variant_index_re)
    else:
        variant_filenames = set()

    # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
    non_variant_file_re = re.compile(
        rf"({prefix_alternatives})(-{transformers_index_format})?\.({suffix_alternatives})$"
    )
    # `text_encoder/pytorch_model.bin.index.json`
    non_variant_index_re = re.compile(rf"({prefix_alternatives})\.({suffix_alternatives})\.index\.json")
    non_variant_filenames = matching(non_variant_file_re) | matching(non_variant_index_re)

    # Compiled once here instead of on every convert_to_variant call.
    sharded_re = re.compile(f"^(.*?){transformers_index_format}")

    def convert_to_variant(filename):
        """Build the variant-named counterpart of a non-variant file path."""
        if "index" in filename:
            return filename.replace("index", f"index.{variant}")
        if sharded_re.match(filename) is not None:
            head, *tail = filename.split("-")
            return f"{head}.{variant}-{'-'.join(tail)}"
        dot_parts = filename.split(".")
        return f"{dot_parts[0]}.{variant}.{dot_parts[1]}"

    # all variant filenames will be used by default
    usable_filenames = set(variant_filenames)

    # A non-variant file is only used when its variant counterpart has not been selected.
    for f in non_variant_filenames:
        if convert_to_variant(f) not in usable_filenames:
            usable_filenames.add(f)

    return usable_filenames, variant_filenames
|
198
|
+
|
199
|
+
|
200
|
+
@validate_hf_hub_args
def warn_deprecated_model_variant(pretrained_model_name_or_path, token, variant, revision, model_filenames):
    """
    Emit a `FutureWarning` about loading a model variant through `revision`.

    The repo's file listing is fetched with `model_info` (called with `revision=None`). `revision` is then used as the
    variant label to pick the compatible files from that listing, and the second dot-separated segment is removed
    from each picked name. One of two messages is issued depending on whether every entry of `model_filenames` is
    among the resulting names.

    Args:
        pretrained_model_name_or_path: Repo identifier passed to `model_info` and quoted in the warning.
        token: Passed through to `model_info`.
        variant: Accepted but not read by this function.
        revision: The revision the caller used; treated as the variant label here.
        model_filenames: Iterable of file names to look for among the compatible files.
    """
    repo_info = model_info(pretrained_model_name_or_path, token=token, revision=None)
    repo_files = {sibling.rfilename for sibling in repo_info.siblings}
    compatible_files, _ = variant_compatible_siblings(repo_files, variant=revision)

    # Remove the second dot-separated segment from every compatible file name.
    stripped_names = set()
    for compatible_file in compatible_files:
        segments = compatible_file.split(".")
        stripped_names.add(".".join(segments[:1] + segments[2:]))

    if set(model_filenames).issubset(stripped_names):
        message = f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead."
    else:
        message = f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added."

    warnings.warn(message, FutureWarning)
|
221
|
+
|
222
|
+
|
223
|
+
def _unwrap_model(model):
    """
    Return the model found underneath a compiled-module wrapper and/or a PEFT wrapper.

    If `is_compiled_module(model)` holds, `model._orig_mod` is used. Then, when PEFT is available and the result is a
    `PeftModel`, its `base_model.model` is used. Anything else is returned as is.
    """
    unwrapped = model._orig_mod if is_compiled_module(model) else model

    if not is_peft_available():
        return unwrapped

    # Imported lazily: peft is only imported once it is known to be available.
    from peft import PeftModel

    if isinstance(unwrapped, PeftModel):
        unwrapped = unwrapped.base_model.model

    return unwrapped
|
235
|
+
|
236
|
+
|
237
|
+
def maybe_raise_or_warn(
    library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
):
    """
    Raise or warn when the object passed for component `name` may be of the wrong type.

    Args:
        library_name: Importable module name in which `class_name` and the candidate classes are looked up.
        library: Not read; the module is re-imported from `library_name` when a check is performed.
        class_name: Name of the class, inside the library, that the component is declared as.
        importable_classes: Mapping whose keys are candidate base-class names.
        passed_class_obj: Mapping from component name to the object the caller passed.
        name: Key of the component in `passed_class_obj`.
        is_pipeline_module: If true no check is possible and only a warning is logged.

    Raises:
        ValueError: If the class of the (unwrapped) passed object is not a subclass of the expected base class.
    """
    if is_pipeline_module:
        logger.warning(
            f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
            " has the correct type"
        )
        return

    module = importlib.import_module(library_name)
    declared_cls = getattr(module, class_name)
    candidates = {c: getattr(module, c, None) for c in importable_classes.keys()}

    # The last candidate (in mapping order) that the declared class derives from is the expected base class.
    expected_class_obj = None
    for candidate in candidates.values():
        if candidate is not None and issubclass(declared_cls, candidate):
            expected_class_obj = candidate

    # Dynamo wraps the original model in a private class, so compare against the unwrapped model's class.
    model_cls = _unwrap_model(passed_class_obj[name]).__class__

    # NOTE(review): if no candidate matched, `expected_class_obj` is still None here and `issubclass` raises
    # TypeError rather than ValueError.
    if not issubclass(model_cls, expected_class_obj):
        raise ValueError(f"{passed_class_obj[name]} is of type: {model_cls}, but should be {expected_class_obj}")
|
266
|
+
|
267
|
+
|
268
|
+
def get_class_obj_and_candidates(
    library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
    """Resolve a component class and the candidate base classes it may derive from.

    Args:
        library_name: Name of the pipeline sub-module, custom component file (without
            `.py`) or importable library that provides `class_name`.
        class_name: Name of the class to retrieve.
        importable_classes: Mapping whose keys are candidate base-class names.
        pipelines: Object on which `library_name` is looked up when
            `is_pipeline_module` is true.
        is_pipeline_module: Whether the class lives in a pipeline sub-module.
        component_name: Optional sub-folder of `cache_dir` holding a custom component.
        cache_dir: Optional root folder of the downloaded pipeline.

    Returns:
        A `(class_obj, class_candidates)` tuple. `class_candidates` maps every key of
        `importable_classes` either to `class_obj` itself (pipeline module or custom
        component) or to the attribute of that name on the imported library, `None`
        when the library has no such attribute.
    """
    # Both arguments default to None and `os.path.join(None, None)` raises TypeError,
    # so the custom-component folder is only built when both are actually provided.
    if cache_dir is not None and component_name is not None:
        component_folder = os.path.join(cache_dir, component_name)
    else:
        component_folder = None

    if is_pipeline_module:
        pipeline_module = getattr(pipelines, library_name)

        class_obj = getattr(pipeline_module, class_name)
        class_candidates = dict.fromkeys(importable_classes, class_obj)
    elif component_folder is not None and os.path.isfile(os.path.join(component_folder, library_name + ".py")):
        # load custom component
        class_obj = get_class_from_dynamic_module(
            component_folder, module_file=library_name + ".py", class_name=class_name
        )
        class_candidates = dict.fromkeys(importable_classes, class_obj)
    else:
        # else we just import it from the library.
        library = importlib.import_module(library_name)

        class_obj = getattr(library, class_name)
        class_candidates = {c: getattr(library, c, None) for c in importable_classes}

    return class_obj, class_candidates
|
293
|
+
|
294
|
+
|
295
|
+
def _get_pipeline_class(
|
296
|
+
class_obj,
|
297
|
+
config=None,
|
298
|
+
load_connected_pipeline=False,
|
299
|
+
custom_pipeline=None,
|
300
|
+
repo_id=None,
|
301
|
+
hub_revision=None,
|
302
|
+
class_name=None,
|
303
|
+
cache_dir=None,
|
304
|
+
revision=None,
|
305
|
+
):
|
306
|
+
if custom_pipeline is not None:
|
307
|
+
if custom_pipeline.endswith(".py"):
|
308
|
+
path = Path(custom_pipeline)
|
309
|
+
# decompose into folder & file
|
310
|
+
file_name = path.name
|
311
|
+
custom_pipeline = path.parent.absolute()
|
312
|
+
elif repo_id is not None:
|
313
|
+
file_name = f"{custom_pipeline}.py"
|
314
|
+
custom_pipeline = repo_id
|
315
|
+
else:
|
316
|
+
file_name = CUSTOM_PIPELINE_FILE_NAME
|
317
|
+
|
318
|
+
if repo_id is not None and hub_revision is not None:
|
319
|
+
# if we load the pipeline code from the Hub
|
320
|
+
# make sure to overwrite the `revision`
|
321
|
+
revision = hub_revision
|
322
|
+
|
323
|
+
return get_class_from_dynamic_module(
|
324
|
+
custom_pipeline,
|
325
|
+
module_file=file_name,
|
326
|
+
class_name=class_name,
|
327
|
+
cache_dir=cache_dir,
|
328
|
+
revision=revision,
|
329
|
+
)
|
330
|
+
|
331
|
+
if class_obj.__name__ != "DiffusionPipeline":
|
332
|
+
return class_obj
|
333
|
+
|
334
|
+
diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
|
335
|
+
class_name = class_name or config["_class_name"]
|
336
|
+
if not class_name:
|
337
|
+
raise ValueError(
|
338
|
+
"The class name could not be found in the configuration file. Please make sure to pass the correct `class_name`."
|
339
|
+
)
|
340
|
+
|
341
|
+
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
|
342
|
+
|
343
|
+
pipeline_cls = getattr(diffusers_module, class_name)
|
344
|
+
|
345
|
+
if load_connected_pipeline:
|
346
|
+
from .auto_pipeline import _get_connected_pipeline
|
347
|
+
|
348
|
+
connected_pipeline_cls = _get_connected_pipeline(pipeline_cls)
|
349
|
+
if connected_pipeline_cls is not None:
|
350
|
+
logger.info(
|
351
|
+
f"Loading connected pipeline {connected_pipeline_cls.__name__} instead of {pipeline_cls.__name__} as specified via `load_connected_pipeline=True`"
|
352
|
+
)
|
353
|
+
else:
|
354
|
+
logger.info(f"{pipeline_cls.__name__} has no connected pipeline class. Loading {pipeline_cls.__name__}.")
|
355
|
+
|
356
|
+
pipeline_cls = connected_pipeline_cls or pipeline_cls
|
357
|
+
|
358
|
+
return pipeline_cls
|
359
|
+
|
360
|
+
|
361
|
+
def load_sub_model(
    library_name: str,
    class_name: str,
    importable_classes: List[Any],
    pipelines: Any,
    is_pipeline_module: bool,
    pipeline_class: Any,
    torch_dtype: torch.dtype,
    provider: Any,
    sess_options: Any,
    device_map: Optional[Union[Dict[str, torch.device], str]],
    max_memory: Optional[Dict[Union[int, str], Union[int, str]]],
    offload_folder: Optional[Union[str, os.PathLike]],
    offload_state_dict: bool,
    model_variants: Dict[str, str],
    name: str,
    from_flax: bool,
    variant: str,
    low_cpu_mem_usage: bool,
    cached_folder: Union[str, os.PathLike],
):
    """Load the pipeline component `name` given its `library_name` and `class_name`.

    The component class is resolved with `get_class_obj_and_candidates`, the name of
    its loading method is taken from index 1 of the matching `importable_classes`
    entry, and that method is called on `cached_folder/name` when this directory
    exists, otherwise on `cached_folder` itself.

    Note that the entry for `name` is popped from `model_variants` when the component
    is a diffusers or transformers model.

    Raises:
        ValueError: If the class derives from none of the candidate classes.
        ImportError: If a variant is requested for a transformers model while the
            installed `transformers` is older than 4.27.0.
    """
    class_obj, class_candidates = get_class_obj_and_candidates(
        library_name,
        class_name,
        importable_classes,
        pipelines,
        is_pipeline_module,
        component_name=name,
        cache_dir=cached_folder,
    )

    # Scan every candidate without stopping early: the last matching base class
    # decides which loading method is used.
    load_method_name = None
    for candidate_name, candidate_cls in class_candidates.items():
        if candidate_cls is None:
            continue
        if issubclass(class_obj, candidate_cls):
            load_method_name = importable_classes[candidate_name][1]

    if load_method_name is None:
        # No loading method found -> this is either a dummy module or unsupported.
        none_module = class_obj.__module__
        is_dummy_path = none_module.startswith((DUMMY_MODULES_FOLDER, TRANSFORMERS_DUMMY_MODULES_FOLDER))
        if is_dummy_path and "dummy" in none_module:
            # Instantiating a dummy class yields a helpful missing-requirements error.
            class_obj()

        raise ValueError(
            f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
            f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
        )

    load_method = getattr(class_obj, load_method_name)

    # Collect the keyword arguments accepted by this kind of component.
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    loading_kwargs = {}
    if issubclass(class_obj, torch.nn.Module):
        loading_kwargs["torch_dtype"] = torch_dtype
    if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
        loading_kwargs.update(provider=provider, sess_options=sess_options)

    is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)

    transformers_version = (
        version.parse(version.parse(transformers.__version__).base_version)
        if is_transformers_available()
        else "N/A"
    )

    is_transformers_model = (
        is_transformers_available()
        and issubclass(class_obj, PreTrainedModel)
        and transformers_version >= version.parse("4.20.0")
    )

    # With `device_map=None` transformers initializes the weights, unlike diffusers.
    # Forwarding `low_cpu_mem_usage` (True by default) skips that initialization and
    # makes default loading considerably faster.
    if is_diffusers_model or is_transformers_model:
        loading_kwargs.update(
            device_map=device_map,
            max_memory=max_memory,
            offload_folder=offload_folder,
            offload_state_dict=offload_state_dict,
            variant=model_variants.pop(name, None),
        )

        if from_flax:
            loading_kwargs["from_flax"] = True

        # This transformers-specific handling can go once the minimum required
        # `transformers` version is higher than 4.27.
        if is_transformers_model:
            if loading_kwargs["variant"] is None:
                loading_kwargs.pop("variant")
            elif transformers_version < version.parse("4.27.0"):
                raise ImportError(
                    f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
                )

        # A transformers model loaded from flax cannot use `low_cpu_mem_usage` yet.
        loading_kwargs["low_cpu_mem_usage"] = False if (from_flax and is_transformers_model) else low_cpu_mem_usage

    # Prefer the component's own sub-directory, fall back to the root directory.
    component_dir = os.path.join(cached_folder, name)
    load_source = component_dir if os.path.isdir(component_dir) else cached_folder
    return load_method(load_source, **loading_kwargs)
|
479
|
+
|
480
|
+
|
481
|
+
def _fetch_class_library_tuple(module):
    """Return the `(library, class_name)` pair describing `module`.

    `class_name` is the class name of the unwrapped module. `library` is chosen as:
      * the second-to-last component of the module path, when that path has more than
        two components and the `pipelines` package has an attribute of that name;
      * otherwise the top-level package name, when it is a key of `LOADABLE_CLASSES`;
      * otherwise the full dotted module path (a custom module).
    """
    # Resolved at call time rather than import time to avoid a circular import.
    root_package = importlib.import_module(__name__.split(".")[0])
    pipelines = getattr(root_package, "pipelines")

    # Describe the original module, not a compiled/wrapped one.
    original_module = _unwrap_model(module)
    module_path = original_module.__module__
    path_parts = module_path.split(".")

    # A pipeline module lives inside a folder that `pipelines` exposes.
    pipeline_dir = path_parts[-2] if len(path_parts) > 2 else None
    is_pipeline_module = pipeline_dir is not None and hasattr(pipelines, pipeline_dir)

    if is_pipeline_module:
        library = pipeline_dir
    elif path_parts[0] in LOADABLE_CLASSES:
        library = path_parts[0]
    else:
        library = module_path

    return (library, original_module.__class__.__name__)
|