diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +19 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +20 -26
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +42 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
- diffusers/schedulers/scheduling_edm_euler.py +50 -31
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
- diffusers/schedulers/scheduling_euler_discrete.py +160 -68
- diffusers/schedulers/scheduling_heun_discrete.py +57 -39
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +24 -26
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
- diffusers-0.28.0.dist-info/RECORD +414 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
- diffusers-0.27.2.dist-info/RECORD +0 -399
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0

diffusers/pipelines/pipeline_loading_utils.py

```diff
@@ -22,15 +22,19 @@ from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
 import torch
-from huggingface_hub import (
-    model_info,
-)
+from huggingface_hub import model_info
+from huggingface_hub.utils import validate_hf_hub_args
 from packaging import version
 
+from .. import __version__
 from ..utils import (
+    FLAX_WEIGHTS_NAME,
+    ONNX_EXTERNAL_WEIGHTS_NAME,
+    ONNX_WEIGHTS_NAME,
     SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_NAME,
     get_class_from_dynamic_module,
+    is_accelerate_available,
     is_peft_available,
     is_transformers_available,
     logging,
```
```diff
@@ -44,9 +48,12 @@ if is_transformers_available():
     from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
     from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
     from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
-    from huggingface_hub.utils import validate_hf_hub_args
 
-
+if is_accelerate_available():
+    import accelerate
+    from accelerate import dispatch_model
+    from accelerate.hooks import remove_hook_from_module
+    from accelerate.utils import compute_module_sizes, get_max_memory
 
 
 INDEX_FILE = "diffusion_pytorch_model.bin"
```
```diff
@@ -292,6 +299,39 @@ def get_class_obj_and_candidates(
     return class_obj, class_candidates
 
 
+def _get_custom_pipeline_class(
+    custom_pipeline,
+    repo_id=None,
+    hub_revision=None,
+    class_name=None,
+    cache_dir=None,
+    revision=None,
+):
+    if custom_pipeline.endswith(".py"):
+        path = Path(custom_pipeline)
+        # decompose into folder & file
+        file_name = path.name
+        custom_pipeline = path.parent.absolute()
+    elif repo_id is not None:
+        file_name = f"{custom_pipeline}.py"
+        custom_pipeline = repo_id
+    else:
+        file_name = CUSTOM_PIPELINE_FILE_NAME
+
+    if repo_id is not None and hub_revision is not None:
+        # if we load the pipeline code from the Hub
+        # make sure to overwrite the `revision`
+        revision = hub_revision
+
+    return get_class_from_dynamic_module(
+        custom_pipeline,
+        module_file=file_name,
+        class_name=class_name,
+        cache_dir=cache_dir,
+        revision=revision,
+    )
+
+
 def _get_pipeline_class(
     class_obj,
     config=None,
```
```diff
@@ -304,25 +344,10 @@ def _get_pipeline_class(
     revision=None,
 ):
     if custom_pipeline is not None:
-        if custom_pipeline.endswith(".py"):
-            path = Path(custom_pipeline)
-            # decompose into folder & file
-            file_name = path.name
-            custom_pipeline = path.parent.absolute()
-        elif repo_id is not None:
-            file_name = f"{custom_pipeline}.py"
-            custom_pipeline = repo_id
-        else:
-            file_name = CUSTOM_PIPELINE_FILE_NAME
-
-        if repo_id is not None and hub_revision is not None:
-            # if we load the pipeline code from the Hub
-            # make sure to overwrite the `revision`
-            revision = hub_revision
-
-        return get_class_from_dynamic_module(
+        return _get_custom_pipeline_class(
             custom_pipeline,
-            module_file=file_name,
+            repo_id=repo_id,
+            hub_revision=hub_revision,
             class_name=class_name,
             cache_dir=cache_dir,
             revision=revision,
```
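The two hunks above factor the community ("custom") pipeline resolution out of `_get_pipeline_class` into the new `_get_custom_pipeline_class` helper; the public entry point is unchanged. A minimal usage sketch of the path this helper serves (the checkpoint id and community pipeline name below are illustrative, not taken from this diff):

```python
import torch
from diffusers import DiffusionPipeline

# Illustrative ids: any Hub checkpoint plus any community pipeline name follow the same path.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="lpw_stable_diffusion",  # resolved via get_class_from_dynamic_module, as above
    torch_dtype=torch.float16,
)
```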
```diff
@@ -358,6 +383,209 @@ def _get_pipeline_class(
     return pipeline_cls
 
 
+def _load_empty_model(
+    library_name: str,
+    class_name: str,
+    importable_classes: List[Any],
+    pipelines: Any,
+    is_pipeline_module: bool,
+    name: str,
+    torch_dtype: Union[str, torch.dtype],
+    cached_folder: Union[str, os.PathLike],
+    **kwargs,
+):
+    # retrieve class objects.
+    class_obj, _ = get_class_obj_and_candidates(
+        library_name,
+        class_name,
+        importable_classes,
+        pipelines,
+        is_pipeline_module,
+        component_name=name,
+        cache_dir=cached_folder,
+    )
+
+    if is_transformers_available():
+        transformers_version = version.parse(version.parse(transformers.__version__).base_version)
+    else:
+        transformers_version = "N/A"
+
+    # Determine library.
+    is_transformers_model = (
+        is_transformers_available()
+        and issubclass(class_obj, PreTrainedModel)
+        and transformers_version >= version.parse("4.20.0")
+    )
+    diffusers_module = importlib.import_module(__name__.split(".")[0])
+    is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
+
+    model = None
+    config_path = cached_folder
+    user_agent = {
+        "diffusers": __version__,
+        "file_type": "model",
+        "framework": "pytorch",
+    }
+
+    if is_diffusers_model:
+        # Load config and then the model on meta.
+        config, unused_kwargs, commit_hash = class_obj.load_config(
+            os.path.join(config_path, name),
+            cache_dir=cached_folder,
+            return_unused_kwargs=True,
+            return_commit_hash=True,
+            force_download=kwargs.pop("force_download", False),
+            resume_download=kwargs.pop("resume_download", None),
+            proxies=kwargs.pop("proxies", None),
+            local_files_only=kwargs.pop("local_files_only", False),
+            token=kwargs.pop("token", None),
+            revision=kwargs.pop("revision", None),
+            subfolder=kwargs.pop("subfolder", None),
+            user_agent=user_agent,
+        )
+        with accelerate.init_empty_weights():
+            model = class_obj.from_config(config, **unused_kwargs)
+    elif is_transformers_model:
+        config_class = getattr(class_obj, "config_class", None)
+        if config_class is None:
+            raise ValueError("`config_class` cannot be None. Please double-check the model.")
+
+        config = config_class.from_pretrained(
+            cached_folder,
+            subfolder=name,
+            force_download=kwargs.pop("force_download", False),
+            resume_download=kwargs.pop("resume_download", None),
+            proxies=kwargs.pop("proxies", None),
+            local_files_only=kwargs.pop("local_files_only", False),
+            token=kwargs.pop("token", None),
+            revision=kwargs.pop("revision", None),
+            user_agent=user_agent,
+        )
+        with accelerate.init_empty_weights():
+            model = class_obj(config)
+
+    if model is not None:
+        model = model.to(dtype=torch_dtype)
+    return model
+
+
+def _assign_components_to_devices(
+    module_sizes: Dict[str, float], device_memory: Dict[str, float], device_mapping_strategy: str = "balanced"
+):
+    device_ids = list(device_memory.keys())
+    device_cycle = device_ids + device_ids[::-1]
+    device_memory = device_memory.copy()
+
+    device_id_component_mapping = {}
+    current_device_index = 0
+    for component in module_sizes:
+        device_id = device_cycle[current_device_index % len(device_cycle)]
+        component_memory = module_sizes[component]
+        curr_device_memory = device_memory[device_id]
+
+        # If the GPU doesn't fit the current component offload to the CPU.
+        if component_memory > curr_device_memory:
+            device_id_component_mapping["cpu"] = [component]
+        else:
+            if device_id not in device_id_component_mapping:
+                device_id_component_mapping[device_id] = [component]
+            else:
+                device_id_component_mapping[device_id].append(component)
+
+            # Update the device memory.
+            device_memory[device_id] -= component_memory
+        current_device_index += 1
+
+    return device_id_component_mapping
+
+
+def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs):
+    # To avoid circular import problem.
+    from diffusers import pipelines
+
+    torch_dtype = kwargs.get("torch_dtype", torch.float32)
+
+    # Load each module in the pipeline on a meta device so that we can derive the device map.
+    init_empty_modules = {}
+    for name, (library_name, class_name) in init_dict.items():
+        if class_name.startswith("Flax"):
+            raise ValueError("Flax pipelines are not supported with `device_map`.")
+
+        # Define all importable classes
+        is_pipeline_module = hasattr(pipelines, library_name)
+        importable_classes = ALL_IMPORTABLE_CLASSES
+        loaded_sub_model = None
+
+        # Use passed sub model or load class_name from library_name
+        if name in passed_class_obj:
+            # if the model is in a pipeline module, then we load it from the pipeline
+            # check that passed_class_obj has correct parent class
+            maybe_raise_or_warn(
+                library_name,
+                library,
+                class_name,
+                importable_classes,
+                passed_class_obj,
+                name,
+                is_pipeline_module,
+            )
+            with accelerate.init_empty_weights():
+                loaded_sub_model = passed_class_obj[name]
+
+        else:
+            loaded_sub_model = _load_empty_model(
+                library_name=library_name,
+                class_name=class_name,
+                importable_classes=importable_classes,
+                pipelines=pipelines,
+                is_pipeline_module=is_pipeline_module,
+                pipeline_class=pipeline_class,
+                name=name,
+                torch_dtype=torch_dtype,
+                cached_folder=kwargs.get("cached_folder", None),
+                force_download=kwargs.get("force_download", None),
+                resume_download=kwargs.get("resume_download", None),
+                proxies=kwargs.get("proxies", None),
+                local_files_only=kwargs.get("local_files_only", None),
+                token=kwargs.get("token", None),
+                revision=kwargs.get("revision", None),
+            )
+
+        if loaded_sub_model is not None:
+            init_empty_modules[name] = loaded_sub_model
+
+    # determine device map
+    # Obtain a sorted dictionary for mapping the model-level components
+    # to their sizes.
+    module_sizes = {
+        module_name: compute_module_sizes(module, dtype=torch_dtype)[""]
+        for module_name, module in init_empty_modules.items()
+        if isinstance(module, torch.nn.Module)
+    }
+    module_sizes = dict(sorted(module_sizes.items(), key=lambda item: item[1], reverse=True))
+
+    # Obtain maximum memory available per device (GPUs only).
+    max_memory = get_max_memory(max_memory)
+    max_memory = dict(sorted(max_memory.items(), key=lambda item: item[1], reverse=True))
+    max_memory = {k: v for k, v in max_memory.items() if k != "cpu"}
+
+    # Obtain a dictionary mapping the model-level components to the available
+    # devices based on the maximum memory and the model sizes.
+    final_device_map = None
+    if len(max_memory) > 0:
+        device_id_component_mapping = _assign_components_to_devices(
+            module_sizes, max_memory, device_mapping_strategy=device_map
+        )
+
+        # Obtain the final device map, e.g., `{"unet": 0, "text_encoder": 1, "vae": 1, ...}`
+        final_device_map = {}
+        for device_id, components in device_id_component_mapping.items():
+            for component in components:
+                final_device_map[component] = device_id
+
+    return final_device_map
+
+
 def load_sub_model(
     library_name: str,
     class_name: str,
```
```diff
@@ -381,6 +609,7 @@ def load_sub_model(
 ):
     """Helper method to load the module `name` from `library_name` and `class_name`"""
     # retrieve class candidates
+
     class_obj, class_candidates = get_class_obj_and_candidates(
         library_name,
         class_name,
```
```diff
@@ -475,6 +704,22 @@ def load_sub_model(
         # else load from the root directory
         loaded_sub_model = load_method(cached_folder, **loading_kwargs)
 
+    if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict):
+        # remove hooks
+        remove_hook_from_module(loaded_sub_model, recurse=True)
+        needs_offloading_to_cpu = device_map[""] == "cpu"
+
+        if needs_offloading_to_cpu:
+            dispatch_model(
+                loaded_sub_model,
+                state_dict=loaded_sub_model.state_dict(),
+                device_map=device_map,
+                force_hooks=True,
+                main_device=0,
+            )
+        else:
+            dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)
+
     return loaded_sub_model
 
 
```
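Taken together, `_load_empty_model`, `_assign_components_to_devices`, `_get_final_device_map`, and the `dispatch_model` branch above make up the new pipeline-level device-mapping path in 0.28.0. A minimal usage sketch, assuming the public `from_pretrained` call accepts the "balanced" strategy (the default of `device_mapping_strategy`) and exposes the computed map as `hf_device_map`; the checkpoint id and two-GPU setup are illustrative:

```python
import torch
from diffusers import DiffusionPipeline

# Assumes a machine with at least two CUDA devices.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    device_map="balanced",  # spread the pipeline components across the available GPUs
)
# Expected to resemble the map format noted in the diff, e.g. {"unet": 0, "text_encoder": 1, "vae": 1}
print(pipe.hf_device_map)
```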