diffusers 0.26.3__py3-none-any.whl → 0.27.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- diffusers/__init__.py +20 -1
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/diffusers_cli.py +1 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +7 -3
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +2 -2
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/image_processor.py +110 -4
- diffusers/loaders/autoencoder.py +7 -8
- diffusers/loaders/controlnet.py +17 -8
- diffusers/loaders/ip_adapter.py +86 -23
- diffusers/loaders/lora.py +105 -310
- diffusers/loaders/lora_conversion_utils.py +1 -1
- diffusers/loaders/peft.py +1 -1
- diffusers/loaders/single_file.py +51 -12
- diffusers/loaders/single_file_utils.py +274 -49
- diffusers/loaders/textual_inversion.py +23 -4
- diffusers/loaders/unet.py +195 -41
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +3 -1
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +26 -36
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +171 -114
- diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +1 -1
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flax.py +1 -1
- diffusers/models/downsampling.py +8 -12
- diffusers/models/dual_transformer_2d.py +1 -1
- diffusers/models/embeddings.py +3 -4
- diffusers/models/embeddings_flax.py +1 -1
- diffusers/models/lora.py +33 -10
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +4 -6
- diffusers/models/normalization.py +1 -1
- diffusers/models/resnet.py +31 -58
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/t5_film_transformer.py +1 -1
- diffusers/models/transformer_2d.py +1 -1
- diffusers/models/transformer_temporal.py +1 -1
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/t5_film_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +29 -31
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unet_1d.py +1 -1
- diffusers/models/unet_1d_blocks.py +1 -1
- diffusers/models/unet_2d.py +1 -1
- diffusers/models/unet_2d_blocks.py +1 -1
- diffusers/models/unet_2d_condition.py +1 -1
- diffusers/models/unets/__init__.py +1 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +4 -4
- diffusers/models/unets/unet_2d_blocks.py +238 -98
- diffusers/models/unets/unet_2d_blocks_flax.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +420 -323
- diffusers/models/unets/unet_2d_condition_flax.py +21 -12
- diffusers/models/unets/unet_3d_blocks.py +50 -40
- diffusers/models/unets/unet_3d_condition.py +47 -8
- diffusers/models/unets/unet_i2vgen_xl.py +75 -30
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +48 -8
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +610 -0
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +10 -16
- diffusers/models/vae_flax.py +1 -1
- diffusers/models/vq_model.py +1 -1
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +26 -0
- diffusers/pipelines/amused/pipeline_amused.py +1 -1
- diffusers/pipelines/amused/pipeline_amused_img2img.py +1 -1
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff.py +162 -417
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +165 -137
- diffusers/pipelines/animatediff/pipeline_output.py +7 -6
- diffusers/pipelines/audioldm/pipeline_audioldm.py +3 -19
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +3 -3
- diffusers/pipelines/auto_pipeline.py +7 -16
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +90 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +98 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +92 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +145 -70
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +126 -89
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +108 -96
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -1
- diffusers/pipelines/ddim/pipeline_ddim.py +1 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +5 -5
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +4 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +5 -5
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +5 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +10 -120
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -91
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +1 -1
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +1 -1
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +5 -4
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +5 -4
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +7 -22
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -39
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +5 -5
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -22
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -2
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +1 -1
- diffusers/pipelines/free_init_utils.py +184 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +22 -104
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +2 -2
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +104 -93
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +112 -74
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ledits_pp/__init__.py +55 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +1505 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +1797 -0
- diffusers/pipelines/ledits_pp/pipeline_output.py +43 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +3 -19
- diffusers/pipelines/onnx_utils.py +1 -1
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +3 -3
- diffusers/pipelines/pia/pipeline_pia.py +168 -327
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +508 -0
- diffusers/pipelines/pipeline_utils.py +188 -534
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +56 -10
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +3 -3
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/__init__.py +50 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +482 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +311 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +638 -0
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +4 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +90 -146
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +4 -32
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -119
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -119
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +13 -59
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -31
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -33
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -21
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -21
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +5 -21
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +9 -38
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -34
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +6 -35
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +7 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +4 -124
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +282 -80
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +94 -46
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +3 -3
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +6 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +96 -148
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -154
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +98 -153
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +25 -87
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +89 -80
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +5 -49
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +80 -88
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +15 -86
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +20 -93
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +5 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +3 -19
- diffusers/pipelines/unclip/pipeline_unclip.py +1 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -1
- diffusers/pipelines/unclip/text_proj.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +35 -35
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +4 -21
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -5
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +2 -2
- diffusers/schedulers/__init__.py +7 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +1 -1
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +42 -19
- diffusers/schedulers/scheduling_ddim.py +2 -4
- diffusers/schedulers/scheduling_ddim_flax.py +13 -5
- diffusers/schedulers/scheduling_ddim_inverse.py +2 -4
- diffusers/schedulers/scheduling_ddim_parallel.py +2 -4
- diffusers/schedulers/scheduling_ddpm.py +2 -4
- diffusers/schedulers/scheduling_ddpm_flax.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -4
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +46 -19
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -21
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +9 -7
- diffusers/schedulers/scheduling_dpmsolver_sde.py +35 -35
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +49 -18
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +683 -0
- diffusers/schedulers/scheduling_edm_euler.py +381 -0
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +43 -15
- diffusers/schedulers/scheduling_euler_discrete.py +42 -17
- diffusers/schedulers/scheduling_euler_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_heun_discrete.py +35 -35
- diffusers/schedulers/scheduling_ipndm.py +37 -11
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +44 -44
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +44 -44
- diffusers/schedulers/scheduling_karras_ve_flax.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +38 -14
- diffusers/schedulers/scheduling_lms_discrete.py +43 -15
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +2 -4
- diffusers/schedulers/scheduling_pndm_flax.py +2 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +41 -9
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +1 -1
- diffusers/schedulers/scheduling_tcd.py +686 -0
- diffusers/schedulers/scheduling_unclip.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +46 -19
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +9 -2
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +1 -1
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +75 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +3 -3
- diffusers/utils/hub_utils.py +60 -16
- diffusers/utils/import_utils.py +15 -1
- diffusers/utils/loading_utils.py +2 -0
- diffusers/utils/logging.py +1 -1
- diffusers/utils/model_card_template.md +24 -0
- diffusers/utils/outputs.py +14 -7
- diffusers/utils/peft_utils.py +1 -1
- diffusers/utils/state_dict_utils.py +1 -1
- diffusers/utils/testing_utils.py +2 -0
- diffusers/utils/torch_utils.py +1 -1
- {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/METADATA +46 -46
- diffusers-0.27.0.dist-info/RECORD +399 -0
- {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/WHEEL +1 -1
- diffusers-0.26.3.dist-info/RECORD +0 -384
- {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/LICENSE +0 -0
- {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright
|
2
|
+
# Copyright 2024 The HuggingFace Inc. team.
|
3
3
|
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
4
4
|
#
|
5
5
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
@@ -19,7 +19,6 @@ import inspect
|
|
19
19
|
import os
|
20
20
|
import re
|
21
21
|
import sys
|
22
|
-
import warnings
|
23
22
|
from dataclasses import dataclass
|
24
23
|
from pathlib import Path
|
25
24
|
from typing import Any, Callable, Dict, List, Optional, Union
|
@@ -42,21 +41,20 @@ from tqdm.auto import tqdm
|
|
42
41
|
|
43
42
|
from .. import __version__
|
44
43
|
from ..configuration_utils import ConfigMixin
|
44
|
+
from ..models import AutoencoderKL
|
45
|
+
from ..models.attention_processor import FusedAttnProcessor2_0
|
45
46
|
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
|
46
47
|
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
47
48
|
from ..utils import (
|
48
49
|
CONFIG_NAME,
|
49
50
|
DEPRECATED_REVISION_ARGS,
|
50
|
-
SAFETENSORS_WEIGHTS_NAME,
|
51
|
-
WEIGHTS_NAME,
|
52
51
|
BaseOutput,
|
52
|
+
PushToHubMixin,
|
53
53
|
deprecate,
|
54
|
-
get_class_from_dynamic_module,
|
55
54
|
is_accelerate_available,
|
56
55
|
is_accelerate_version,
|
57
|
-
|
56
|
+
is_torch_npu_available,
|
58
57
|
is_torch_version,
|
59
|
-
is_transformers_available,
|
60
58
|
logging,
|
61
59
|
numpy_to_pil,
|
62
60
|
)
|
@@ -64,55 +62,37 @@ from ..utils.hub_utils import load_or_create_model_card, populate_model_card
|
|
64
62
|
from ..utils.torch_utils import is_compiled_module
|
65
63
|
|
66
64
|
|
67
|
-
if
|
68
|
-
import
|
69
|
-
from transformers import PreTrainedModel
|
70
|
-
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
|
71
|
-
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
|
72
|
-
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
|
65
|
+
if is_torch_npu_available():
|
66
|
+
import torch_npu # noqa: F401
|
73
67
|
|
74
|
-
|
68
|
+
|
69
|
+
from .pipeline_loading_utils import (
|
70
|
+
ALL_IMPORTABLE_CLASSES,
|
71
|
+
CONNECTED_PIPES_KEYS,
|
72
|
+
CUSTOM_PIPELINE_FILE_NAME,
|
73
|
+
LOADABLE_CLASSES,
|
74
|
+
_fetch_class_library_tuple,
|
75
|
+
_get_pipeline_class,
|
76
|
+
_unwrap_model,
|
77
|
+
is_safetensors_compatible,
|
78
|
+
load_sub_model,
|
79
|
+
maybe_raise_or_warn,
|
80
|
+
variant_compatible_siblings,
|
81
|
+
warn_deprecated_model_variant,
|
82
|
+
)
|
75
83
|
|
76
84
|
|
77
85
|
if is_accelerate_available():
|
78
86
|
import accelerate
|
79
87
|
|
80
88
|
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"
|
85
|
-
CONNECTED_PIPES_KEYS = ["prior"]
|
86
|
-
|
89
|
+
LIBRARIES = []
|
90
|
+
for library in LOADABLE_CLASSES:
|
91
|
+
LIBRARIES.append(library)
|
87
92
|
|
88
93
|
logger = logging.get_logger(__name__)
|
89
94
|
|
90
95
|
|
91
|
-
LOADABLE_CLASSES = {
|
92
|
-
"diffusers": {
|
93
|
-
"ModelMixin": ["save_pretrained", "from_pretrained"],
|
94
|
-
"SchedulerMixin": ["save_pretrained", "from_pretrained"],
|
95
|
-
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
|
96
|
-
"OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
|
97
|
-
},
|
98
|
-
"transformers": {
|
99
|
-
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
|
100
|
-
"PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
|
101
|
-
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
|
102
|
-
"FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
|
103
|
-
"ProcessorMixin": ["save_pretrained", "from_pretrained"],
|
104
|
-
"ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
|
105
|
-
},
|
106
|
-
"onnxruntime.training": {
|
107
|
-
"ORTModule": ["save_pretrained", "from_pretrained"],
|
108
|
-
},
|
109
|
-
}
|
110
|
-
|
111
|
-
ALL_IMPORTABLE_CLASSES = {}
|
112
|
-
for library in LOADABLE_CLASSES:
|
113
|
-
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
|
114
|
-
|
115
|
-
|
116
96
|
@dataclass
|
117
97
|
class ImagePipelineOutput(BaseOutput):
|
118
98
|
"""
|
@@ -140,432 +120,6 @@ class AudioPipelineOutput(BaseOutput):
|
|
140
120
|
audios: np.ndarray
|
141
121
|
|
142
122
|
|
143
|
-
def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
|
144
|
-
"""
|
145
|
-
Checking for safetensors compatibility:
|
146
|
-
- By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
|
147
|
-
files to know which safetensors files are needed.
|
148
|
-
- The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
|
149
|
-
|
150
|
-
Converting default pytorch serialized filenames to safetensors serialized filenames:
|
151
|
-
- For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
|
152
|
-
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
|
153
|
-
extension is replaced with ".safetensors"
|
154
|
-
"""
|
155
|
-
pt_filenames = []
|
156
|
-
|
157
|
-
sf_filenames = set()
|
158
|
-
|
159
|
-
passed_components = passed_components or []
|
160
|
-
|
161
|
-
for filename in filenames:
|
162
|
-
_, extension = os.path.splitext(filename)
|
163
|
-
|
164
|
-
if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
|
165
|
-
continue
|
166
|
-
|
167
|
-
if extension == ".bin":
|
168
|
-
pt_filenames.append(os.path.normpath(filename))
|
169
|
-
elif extension == ".safetensors":
|
170
|
-
sf_filenames.add(os.path.normpath(filename))
|
171
|
-
|
172
|
-
for filename in pt_filenames:
|
173
|
-
# filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extention = '.bam'
|
174
|
-
path, filename = os.path.split(filename)
|
175
|
-
filename, extension = os.path.splitext(filename)
|
176
|
-
|
177
|
-
if filename.startswith("pytorch_model"):
|
178
|
-
filename = filename.replace("pytorch_model", "model")
|
179
|
-
else:
|
180
|
-
filename = filename
|
181
|
-
|
182
|
-
expected_sf_filename = os.path.normpath(os.path.join(path, filename))
|
183
|
-
expected_sf_filename = f"{expected_sf_filename}.safetensors"
|
184
|
-
if expected_sf_filename not in sf_filenames:
|
185
|
-
logger.warning(f"{expected_sf_filename} not found")
|
186
|
-
return False
|
187
|
-
|
188
|
-
return True
|
189
|
-
|
190
|
-
|
191
|
-
def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
|
192
|
-
weight_names = [
|
193
|
-
WEIGHTS_NAME,
|
194
|
-
SAFETENSORS_WEIGHTS_NAME,
|
195
|
-
FLAX_WEIGHTS_NAME,
|
196
|
-
ONNX_WEIGHTS_NAME,
|
197
|
-
ONNX_EXTERNAL_WEIGHTS_NAME,
|
198
|
-
]
|
199
|
-
|
200
|
-
if is_transformers_available():
|
201
|
-
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
|
202
|
-
|
203
|
-
# model_pytorch, diffusion_model_pytorch, ...
|
204
|
-
weight_prefixes = [w.split(".")[0] for w in weight_names]
|
205
|
-
# .bin, .safetensors, ...
|
206
|
-
weight_suffixs = [w.split(".")[-1] for w in weight_names]
|
207
|
-
# -00001-of-00002
|
208
|
-
transformers_index_format = r"\d{5}-of-\d{5}"
|
209
|
-
|
210
|
-
if variant is not None:
|
211
|
-
# `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors`
|
212
|
-
variant_file_re = re.compile(
|
213
|
-
rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
|
214
|
-
)
|
215
|
-
# `text_encoder/pytorch_model.bin.index.fp16.json`
|
216
|
-
variant_index_re = re.compile(
|
217
|
-
rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
|
218
|
-
)
|
219
|
-
|
220
|
-
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
|
221
|
-
non_variant_file_re = re.compile(
|
222
|
-
rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
|
223
|
-
)
|
224
|
-
# `text_encoder/pytorch_model.bin.index.json`
|
225
|
-
non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
|
226
|
-
|
227
|
-
if variant is not None:
|
228
|
-
variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
|
229
|
-
variant_indexes = {f for f in filenames if variant_index_re.match(f.split("/")[-1]) is not None}
|
230
|
-
variant_filenames = variant_weights | variant_indexes
|
231
|
-
else:
|
232
|
-
variant_filenames = set()
|
233
|
-
|
234
|
-
non_variant_weights = {f for f in filenames if non_variant_file_re.match(f.split("/")[-1]) is not None}
|
235
|
-
non_variant_indexes = {f for f in filenames if non_variant_index_re.match(f.split("/")[-1]) is not None}
|
236
|
-
non_variant_filenames = non_variant_weights | non_variant_indexes
|
237
|
-
|
238
|
-
# all variant filenames will be used by default
|
239
|
-
usable_filenames = set(variant_filenames)
|
240
|
-
|
241
|
-
def convert_to_variant(filename):
|
242
|
-
if "index" in filename:
|
243
|
-
variant_filename = filename.replace("index", f"index.{variant}")
|
244
|
-
elif re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
|
245
|
-
variant_filename = f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
|
246
|
-
else:
|
247
|
-
variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
|
248
|
-
return variant_filename
|
249
|
-
|
250
|
-
for f in non_variant_filenames:
|
251
|
-
variant_filename = convert_to_variant(f)
|
252
|
-
if variant_filename not in usable_filenames:
|
253
|
-
usable_filenames.add(f)
|
254
|
-
|
255
|
-
return usable_filenames, variant_filenames
|
256
|
-
|
257
|
-
|
258
|
-
@validate_hf_hub_args
|
259
|
-
def warn_deprecated_model_variant(pretrained_model_name_or_path, token, variant, revision, model_filenames):
|
260
|
-
info = model_info(
|
261
|
-
pretrained_model_name_or_path,
|
262
|
-
token=token,
|
263
|
-
revision=None,
|
264
|
-
)
|
265
|
-
filenames = {sibling.rfilename for sibling in info.siblings}
|
266
|
-
comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision)
|
267
|
-
comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames]
|
268
|
-
|
269
|
-
if set(model_filenames).issubset(set(comp_model_filenames)):
|
270
|
-
warnings.warn(
|
271
|
-
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
|
272
|
-
FutureWarning,
|
273
|
-
)
|
274
|
-
else:
|
275
|
-
warnings.warn(
|
276
|
-
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added.",
|
277
|
-
FutureWarning,
|
278
|
-
)
|
279
|
-
|
280
|
-
|
281
|
-
def _unwrap_model(model):
|
282
|
-
"""Unwraps a model."""
|
283
|
-
if is_compiled_module(model):
|
284
|
-
model = model._orig_mod
|
285
|
-
|
286
|
-
if is_peft_available():
|
287
|
-
from peft import PeftModel
|
288
|
-
|
289
|
-
if isinstance(model, PeftModel):
|
290
|
-
model = model.base_model.model
|
291
|
-
|
292
|
-
return model
|
293
|
-
|
294
|
-
|
295
|
-
def maybe_raise_or_warn(
|
296
|
-
library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
|
297
|
-
):
|
298
|
-
"""Simple helper method to raise or warn in case incorrect module has been passed"""
|
299
|
-
if not is_pipeline_module:
|
300
|
-
library = importlib.import_module(library_name)
|
301
|
-
class_obj = getattr(library, class_name)
|
302
|
-
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
303
|
-
|
304
|
-
expected_class_obj = None
|
305
|
-
for class_name, class_candidate in class_candidates.items():
|
306
|
-
if class_candidate is not None and issubclass(class_obj, class_candidate):
|
307
|
-
expected_class_obj = class_candidate
|
308
|
-
|
309
|
-
# Dynamo wraps the original model in a private class.
|
310
|
-
# I didn't find a public API to get the original class.
|
311
|
-
sub_model = passed_class_obj[name]
|
312
|
-
unwrapped_sub_model = _unwrap_model(sub_model)
|
313
|
-
model_cls = unwrapped_sub_model.__class__
|
314
|
-
|
315
|
-
if not issubclass(model_cls, expected_class_obj):
|
316
|
-
raise ValueError(
|
317
|
-
f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}"
|
318
|
-
)
|
319
|
-
else:
|
320
|
-
logger.warning(
|
321
|
-
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
|
322
|
-
" has the correct type"
|
323
|
-
)
|
324
|
-
|
325
|
-
|
326
|
-
def get_class_obj_and_candidates(
    library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
    """Retrieve the class object of a module together with its potential parent classes.

    Args:
        library_name: Name of the pipeline sub-module, custom component file
            (without `.py`) or importable library that holds `class_name`.
        class_name: Name of the class to look up.
        importable_classes: Mapping whose keys are the candidate base-class names.
        pipelines: Object on which pipeline sub-modules are looked up by attribute.
        is_pipeline_module: Whether `library_name` is an attribute of `pipelines`.
        component_name: Optional sub-folder of `cache_dir` that may contain a
            custom component file.
        cache_dir: Optional root folder of the downloaded pipeline.

    Returns:
        A `(class_obj, class_candidates)` tuple, where `class_candidates` maps
        every key of `importable_classes` to a class object (or `None` when the
        library does not provide that name).
    """
    # Both arguments default to `None`; joining them unconditionally would raise
    # a `TypeError`, so the custom-component lookup is only attempted when a
    # folder can actually be built.
    component_folder = None
    if cache_dir is not None and component_name is not None:
        component_folder = os.path.join(cache_dir, component_name)

    if is_pipeline_module:
        pipeline_module = getattr(pipelines, library_name)

        class_obj = getattr(pipeline_module, class_name)
        class_candidates = {c: class_obj for c in importable_classes.keys()}
    elif component_folder is not None and os.path.isfile(os.path.join(component_folder, library_name + ".py")):
        # load custom component
        class_obj = get_class_from_dynamic_module(
            component_folder, module_file=library_name + ".py", class_name=class_name
        )
        class_candidates = {c: class_obj for c in importable_classes.keys()}
    else:
        # else we just import it from the library.
        library = importlib.import_module(library_name)

        class_obj = getattr(library, class_name)
        class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

    return class_obj, class_candidates
|
351
|
-
|
352
|
-
|
353
|
-
def _get_pipeline_class(
    class_obj,
    config=None,
    load_connected_pipeline=False,
    custom_pipeline=None,
    repo_id=None,
    hub_revision=None,
    class_name=None,
    cache_dir=None,
    revision=None,
):
    """Resolve the pipeline class to use.

    Resolution order:

    1. If `custom_pipeline` is given, the class is loaded through
       `get_class_from_dynamic_module`.
    2. If `class_obj` is not `DiffusionPipeline` itself, `class_obj` is returned.
    3. Otherwise the class named by `class_name` (or, failing that, by
       `config["_class_name"]`) is looked up on the top-level package of
       `class_obj`, with a leading `"Flax"` stripped from the name.
       With `load_connected_pipeline=True` the connected pipeline class is
       returned instead when one exists.

    Raises:
        ValueError: If no class name is passed and none can be read from `config`.
    """
    if custom_pipeline is not None:
        if custom_pipeline.endswith(".py"):
            path = Path(custom_pipeline)
            # decompose into folder & file
            file_name = path.name
            custom_pipeline = path.parent.absolute()
        elif repo_id is not None:
            file_name = f"{custom_pipeline}.py"
            custom_pipeline = repo_id
        else:
            file_name = CUSTOM_PIPELINE_FILE_NAME

        if repo_id is not None and hub_revision is not None:
            # if we load the pipeline code from the Hub
            # make sure to overwrite the `revision`
            revision = hub_revision

        return get_class_from_dynamic_module(
            custom_pipeline,
            module_file=file_name,
            class_name=class_name,
            cache_dir=cache_dir,
            revision=revision,
        )

    if class_obj != DiffusionPipeline:
        return class_obj

    diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])

    if not class_name:
        # `config` defaults to `None` and may lack the key; in both cases fall
        # through to the explanatory error below instead of a bare
        # TypeError / KeyError.
        try:
            class_name = config["_class_name"]
        except (TypeError, KeyError):
            class_name = None

    if not class_name:
        raise ValueError(
            "The class name could not be found in the configuration file. Please make sure to pass the correct `class_name`."
        )

    class_name = class_name[4:] if class_name.startswith("Flax") else class_name

    pipeline_cls = getattr(diffusers_module, class_name)

    if load_connected_pipeline:
        from .auto_pipeline import _get_connected_pipeline

        connected_pipeline_cls = _get_connected_pipeline(pipeline_cls)
        if connected_pipeline_cls is not None:
            logger.info(
                f"Loading connected pipeline {connected_pipeline_cls.__name__} instead of {pipeline_cls.__name__} as specified via `load_connected_pipeline=True`"
            )
        else:
            logger.info(f"{pipeline_cls.__name__} has no connected pipeline class. Loading {pipeline_cls.__name__}.")

        pipeline_cls = connected_pipeline_cls or pipeline_cls

    return pipeline_cls
|
417
|
-
|
418
|
-
|
419
|
-
def load_sub_model(
    library_name: str,
    class_name: str,
    importable_classes: List[Any],
    pipelines: Any,
    is_pipeline_module: bool,
    pipeline_class: Any,
    torch_dtype: torch.dtype,
    provider: Any,
    sess_options: Any,
    device_map: Optional[Union[Dict[str, torch.device], str]],
    max_memory: Optional[Dict[Union[int, str], Union[int, str]]],
    offload_folder: Optional[Union[str, os.PathLike]],
    offload_state_dict: bool,
    model_variants: Dict[str, str],
    name: str,
    from_flax: bool,
    variant: str,
    low_cpu_mem_usage: bool,
    cached_folder: Union[str, os.PathLike],
    revision: str = None,
):
    """Load the pipeline component `name` using the class `class_name` from `library_name`.

    The loading method is picked from `importable_classes` based on which
    candidate base class the resolved class derives from, the keyword arguments
    are assembled according to the kind of model, and the component is loaded
    from `cached_folder/name` when that directory exists, otherwise from
    `cached_folder` itself.

    Note that the entry for `name` is removed from `model_variants` when the
    component is a diffusers or transformers model.

    Raises:
        ValueError: If no loading method can be determined for the class.
        ImportError: If a variant is requested for a transformers model while
            the installed `transformers` is older than 4.27.0.
    """
    # retrieve class candidates
    class_obj, class_candidates = get_class_obj_and_candidates(
        library_name,
        class_name,
        importable_classes,
        pipelines,
        is_pipeline_module,
        component_name=name,
        cache_dir=cached_folder,
    )

    # retrieve load method name; the last matching candidate wins
    load_method_name = None
    for candidate_name, candidate_cls in class_candidates.items():
        if candidate_cls is not None and issubclass(class_obj, candidate_cls):
            load_method_name = importable_classes[candidate_name][1]

    # no load method -> either a dummy module or an unsupported class
    if load_method_name is None:
        none_module = class_obj.__module__
        is_dummy_path = none_module.startswith((DUMMY_MODULES_FOLDER, TRANSFORMERS_DUMMY_MODULES_FOLDER))
        if is_dummy_path and "dummy" in none_module:
            # call class_obj for nice error message of missing requirements
            class_obj()

        raise ValueError(
            f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
            f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
        )

    load_method = getattr(class_obj, load_method_name)

    # add kwargs to loading method
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    loading_kwargs = {}
    if issubclass(class_obj, torch.nn.Module):
        loading_kwargs["torch_dtype"] = torch_dtype
    if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
        loading_kwargs.update(provider=provider, sess_options=sess_options)

    is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)

    if is_transformers_available():
        transformers_version = version.parse(version.parse(transformers.__version__).base_version)
    else:
        transformers_version = "N/A"

    is_transformers_model = (
        is_transformers_available()
        and issubclass(class_obj, PreTrainedModel)
        and transformers_version >= version.parse("4.20.0")
    )

    # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to
    # diffusers. To make default loading faster we pass `low_cpu_mem_usage=low_cpu_mem_usage`, which is `True` by
    # default, so that the weights are not initialized; this significantly speeds up loading.
    if is_diffusers_model or is_transformers_model:
        model_variant = model_variants.pop(name, None)
        loading_kwargs.update(
            device_map=device_map,
            max_memory=max_memory,
            offload_folder=offload_folder,
            offload_state_dict=offload_state_dict,
            variant=model_variant,
        )
        if from_flax:
            loading_kwargs["from_flax"] = True

        # the following can be deleted once the minimum required `transformers` version
        # is higher than 4.27
        if is_transformers_model:
            if model_variant is None:
                loading_kwargs.pop("variant")
            elif transformers_version < version.parse("4.27.0"):
                raise ImportError(
                    f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
                )

        # if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
        flax_transformers = from_flax and is_transformers_model
        loading_kwargs["low_cpu_mem_usage"] = False if flax_transformers else low_cpu_mem_usage

    # load from the component's subdirectory when it exists, else from the root directory
    sub_folder = os.path.join(cached_folder, name)
    if os.path.isdir(sub_folder):
        return load_method(os.path.join(cached_folder, name), **loading_kwargs)
    return load_method(cached_folder, **loading_kwargs)
|
537
|
-
|
538
|
-
|
539
|
-
def _fetch_class_library_tuple(module):
    """Return the `(library, class_name)` pair describing `module`.

    `library` is the parent package name of the module's file when that name
    is an attribute of the `pipelines` package, the top-level package when it
    is listed in `LOADABLE_CLASSES`, and the full module path otherwise.
    """
    # import it here to avoid circular import
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    pipelines = getattr(diffusers_module, "pipelines")

    # register the config from the original module, not the dynamo compiled one
    not_compiled_module = _unwrap_model(module)
    module_name = not_compiled_module.__module__
    path_items = module_name.split(".")

    # check if the module is a pipeline module: the second-to-last path item
    # is the candidate pipeline folder (only defined for paths deeper than two)
    pipeline_dir = path_items[-2] if len(path_items) > 2 else None
    is_pipeline_module = pipeline_dir is not None and hasattr(pipelines, pipeline_dir)

    if is_pipeline_module:
        # the module lives inside the pipeline folder, so use that folder name
        library = pipeline_dir
    elif path_items[0] in LOADABLE_CLASSES:
        library = path_items[0]
    else:
        # not in LOADABLE_CLASSES -> custom module, keep its full module path
        library = module_name

    # retrieve class_name
    class_name = not_compiled_module.__class__.__name__

    return (library, class_name)
|
567
|
-
|
568
|
-
|
569
123
|
class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
570
124
|
r"""
|
571
125
|
Base class for all pipelines.
|
@@ -702,7 +256,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
702
256
|
break
|
703
257
|
|
704
258
|
if save_method_name is None:
|
705
|
-
logger.
|
259
|
+
logger.warning(
|
260
|
+
f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved."
|
261
|
+
)
|
706
262
|
# make sure that unsaveable components are not tried to be loaded afterward
|
707
263
|
self.register_to_config(**{pipeline_component_name: (None, None)})
|
708
264
|
continue
|
@@ -775,32 +331,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
775
331
|
Returns:
|
776
332
|
[`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `device`.
|
777
333
|
"""
|
778
|
-
|
779
|
-
|
780
|
-
if torch_dtype is not None:
|
781
|
-
deprecate("torch_dtype", "0.27.0", "")
|
782
|
-
torch_device = kwargs.pop("torch_device", None)
|
783
|
-
if torch_device is not None:
|
784
|
-
deprecate("torch_device", "0.27.0", "")
|
785
|
-
|
786
|
-
dtype_kwarg = kwargs.pop("dtype", None)
|
787
|
-
device_kwarg = kwargs.pop("device", None)
|
334
|
+
dtype = kwargs.pop("dtype", None)
|
335
|
+
device = kwargs.pop("device", None)
|
788
336
|
silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False)
|
789
337
|
|
790
|
-
if torch_dtype is not None and dtype_kwarg is not None:
|
791
|
-
raise ValueError(
|
792
|
-
"You have passed both `torch_dtype` and `dtype` as a keyword argument. Please make sure to only pass `dtype`."
|
793
|
-
)
|
794
|
-
|
795
|
-
dtype = torch_dtype or dtype_kwarg
|
796
|
-
|
797
|
-
if torch_device is not None and device_kwarg is not None:
|
798
|
-
raise ValueError(
|
799
|
-
"You have passed both `torch_device` and `device` as a keyword argument. Please make sure to only pass `device`."
|
800
|
-
)
|
801
|
-
|
802
|
-
device = torch_device or device_kwarg
|
803
|
-
|
804
338
|
dtype_arg = None
|
805
339
|
device_arg = None
|
806
340
|
if len(args) == 1:
|
@@ -873,12 +407,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
873
407
|
|
874
408
|
if is_loaded_in_8bit and dtype is not None:
|
875
409
|
logger.warning(
|
876
|
-
f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {
|
410
|
+
f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision."
|
877
411
|
)
|
878
412
|
|
879
413
|
if is_loaded_in_8bit and device is not None:
|
880
414
|
logger.warning(
|
881
|
-
f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {
|
415
|
+
f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}."
|
882
416
|
)
|
883
417
|
else:
|
884
418
|
module.to(device, dtype)
|
@@ -1003,10 +537,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1003
537
|
revision (`str`, *optional*, defaults to `"main"`):
|
1004
538
|
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
1005
539
|
allowed by Git.
|
1006
|
-
custom_revision (`str`, *optional
|
540
|
+
custom_revision (`str`, *optional*):
|
1007
541
|
The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
|
1008
|
-
`revision` when loading a custom pipeline from the Hub.
|
1009
|
-
custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
|
542
|
+
`revision` when loading a custom pipeline from the Hub. Defaults to the latest stable 🤗 Diffusers version.
|
1010
543
|
mirror (`str`, *optional*):
|
1011
544
|
Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
|
1012
545
|
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
@@ -1100,6 +633,33 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1100
633
|
use_onnx = kwargs.pop("use_onnx", None)
|
1101
634
|
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
1102
635
|
|
636
|
+
if low_cpu_mem_usage and not is_accelerate_available():
|
637
|
+
low_cpu_mem_usage = False
|
638
|
+
logger.warning(
|
639
|
+
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
640
|
+
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
641
|
+
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
642
|
+
" install accelerate\n```\n."
|
643
|
+
)
|
644
|
+
|
645
|
+
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
646
|
+
raise NotImplementedError(
|
647
|
+
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
648
|
+
" `device_map=None`."
|
649
|
+
)
|
650
|
+
|
651
|
+
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
652
|
+
raise NotImplementedError(
|
653
|
+
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
654
|
+
" `low_cpu_mem_usage=False`."
|
655
|
+
)
|
656
|
+
|
657
|
+
if low_cpu_mem_usage is False and device_map is not None:
|
658
|
+
raise ValueError(
|
659
|
+
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
|
660
|
+
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
661
|
+
)
|
662
|
+
|
1103
663
|
# 1. Download the checkpoints and configs
|
1104
664
|
# use snapshot download here to get it working from from_pretrained
|
1105
665
|
if not os.path.isdir(pretrained_model_name_or_path):
|
@@ -1232,33 +792,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1232
792
|
f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
|
1233
793
|
)
|
1234
794
|
|
1235
|
-
if low_cpu_mem_usage and not is_accelerate_available():
|
1236
|
-
low_cpu_mem_usage = False
|
1237
|
-
logger.warning(
|
1238
|
-
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
1239
|
-
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
1240
|
-
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
1241
|
-
" install accelerate\n```\n."
|
1242
|
-
)
|
1243
|
-
|
1244
|
-
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
1245
|
-
raise NotImplementedError(
|
1246
|
-
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
1247
|
-
" `device_map=None`."
|
1248
|
-
)
|
1249
|
-
|
1250
|
-
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
1251
|
-
raise NotImplementedError(
|
1252
|
-
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
1253
|
-
" `low_cpu_mem_usage=False`."
|
1254
|
-
)
|
1255
|
-
|
1256
|
-
if low_cpu_mem_usage is False and device_map is not None:
|
1257
|
-
raise ValueError(
|
1258
|
-
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
|
1259
|
-
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
1260
|
-
)
|
1261
|
-
|
1262
795
|
# import it here to avoid circular import
|
1263
796
|
from diffusers import pipelines
|
1264
797
|
|
@@ -1303,7 +836,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1303
836
|
variant=variant,
|
1304
837
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
1305
838
|
cached_folder=cached_folder,
|
1306
|
-
revision=revision,
|
1307
839
|
)
|
1308
840
|
logger.info(
|
1309
841
|
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
|
@@ -1445,6 +977,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1445
977
|
|
1446
978
|
device_type = torch_device.type
|
1447
979
|
device = torch.device(f"{device_type}:{self._offload_gpu_id}")
|
980
|
+
self._offload_device = device
|
1448
981
|
|
1449
982
|
if self.device.type != "cpu":
|
1450
983
|
self.to("cpu", silence_dtype_warnings=True)
|
@@ -1494,7 +1027,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1494
1027
|
hook.remove()
|
1495
1028
|
|
1496
1029
|
# make sure the model is in the same state as before calling it
|
1497
|
-
self.enable_model_cpu_offload()
|
1030
|
+
self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
|
1498
1031
|
|
1499
1032
|
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
|
1500
1033
|
r"""
|
@@ -1530,6 +1063,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1530
1063
|
|
1531
1064
|
device_type = torch_device.type
|
1532
1065
|
device = torch.device(f"{device_type}:{self._offload_gpu_id}")
|
1066
|
+
self._offload_device = device
|
1533
1067
|
|
1534
1068
|
if self.device.type != "cpu":
|
1535
1069
|
self.to("cpu", silence_dtype_warnings=True)
|
@@ -1670,7 +1204,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1670
1204
|
try:
|
1671
1205
|
info = model_info(pretrained_model_name, token=token, revision=revision)
|
1672
1206
|
except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e:
|
1673
|
-
logger.
|
1207
|
+
logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
|
1674
1208
|
local_files_only = True
|
1675
1209
|
model_info_call_error = e # save error to reraise it if model is not cached locally
|
1676
1210
|
|
@@ -1821,7 +1355,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1821
1355
|
len(safetensors_variant_filenames) > 0
|
1822
1356
|
and safetensors_model_filenames != safetensors_variant_filenames
|
1823
1357
|
):
|
1824
|
-
logger.
|
1358
|
+
logger.warning(
|
1825
1359
|
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
|
1826
1360
|
)
|
1827
1361
|
else:
|
@@ -1834,7 +1368,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1834
1368
|
bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
|
1835
1369
|
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
|
1836
1370
|
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
|
1837
|
-
logger.
|
1371
|
+
logger.warning(
|
1838
1372
|
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
|
1839
1373
|
)
|
1840
1374
|
|
@@ -1918,7 +1452,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1918
1452
|
else:
|
1919
1453
|
# 2. we forced `local_files_only=True` when `model_info` failed
|
1920
1454
|
raise EnvironmentError(
|
1921
|
-
f"Cannot load model {pretrained_model_name}: model is not cached locally and an error
|
1455
|
+
f"Cannot load model {pretrained_model_name}: model is not cached locally and an error occurred"
|
1922
1456
|
" while trying to fetch metadata from the Hub. Please check out the root cause in the stacktrace"
|
1923
1457
|
" above."
|
1924
1458
|
) from model_info_call_error
|
@@ -2115,3 +1649,123 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
2115
1649
|
|
2116
1650
|
for module in modules:
|
2117
1651
|
module.set_attention_slice(slice_size)
|
1652
|
+
|
1653
|
+
|
1654
|
+
class StableDiffusionMixin:
    r"""
    Helper mixin for a `DiffusionPipeline` that has `vae` and `unet` components (mainly latent diffusion models
    such as Stable Diffusion).
    """

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
        r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.

        The suffixes after the scaling factors represent the stages where they are being applied.

        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
        that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.

        Args:
            s1 (`float`):
                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
                mitigate "oversmoothing effect" in the enhanced denoising process.
            s2 (`float`):
                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
                mitigate "oversmoothing effect" in the enhanced denoising process.
            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.

        Raises:
            ValueError: If the pipeline has no `unet` attribute.
        """
        if not hasattr(self, "unet"):
            raise ValueError("The pipeline must have `unet` for using FreeU.")
        self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)

    def disable_freeu(self):
        """Disables the FreeU mechanism if enabled."""
        self.unet.disable_freeu()

    def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        Args:
            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
            vae (`bool`, defaults to `True`): To apply fusion on the VAE.

        Raises:
            ValueError: If `vae` is `True` and the pipeline's VAE is not an `AutoencoderKL`.
        """
        # Only initialise the flags when they do not exist yet. Resetting them on
        # every call would forget a component fused by an earlier call, and
        # `unfuse_qkv_projections` would then skip it.
        self.fusing_unet = getattr(self, "fusing_unet", False)
        self.fusing_vae = getattr(self, "fusing_vae", False)

        if unet:
            self.fusing_unet = True
            self.unet.fuse_qkv_projections()
            self.unet.set_attn_processor(FusedAttnProcessor2_0())

        if vae:
            if not isinstance(self.vae, AutoencoderKL):
                raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")

            self.fusing_vae = True
            self.vae.fuse_qkv_projections()
            self.vae.set_attn_processor(FusedAttnProcessor2_0())

    def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
        """Disable QKV projection fusion if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        Args:
            unet (`bool`, defaults to `True`): To undo fusion on the UNet.
            vae (`bool`, defaults to `True`): To undo fusion on the VAE.

        """
        # `getattr` with a default: the flags only exist once
        # `fuse_qkv_projections` has been called on this pipeline.
        if unet:
            if not getattr(self, "fusing_unet", False):
                logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
            else:
                self.unet.unfuse_qkv_projections()
                self.fusing_unet = False

        if vae:
            if not getattr(self, "fusing_vae", False):
                logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
            else:
                self.vae.unfuse_qkv_projections()
                self.fusing_vae = False
|