diffusers 0.23.1__py3-none-any.whl → 0.25.0__py3-none-any.whl
- diffusers/__init__.py +26 -2
- diffusers/commands/fp16_safetensors.py +10 -11
- diffusers/configuration_utils.py +13 -8
- diffusers/dependency_versions_check.py +0 -1
- diffusers/dependency_versions_table.py +5 -5
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/image_processor.py +463 -51
- diffusers/loaders/__init__.py +82 -0
- diffusers/loaders/ip_adapter.py +159 -0
- diffusers/loaders/lora.py +1553 -0
- diffusers/loaders/lora_conversion_utils.py +284 -0
- diffusers/loaders/single_file.py +637 -0
- diffusers/loaders/textual_inversion.py +455 -0
- diffusers/loaders/unet.py +828 -0
- diffusers/loaders/utils.py +59 -0
- diffusers/models/__init__.py +26 -9
- diffusers/models/activations.py +9 -6
- diffusers/models/attention.py +301 -29
- diffusers/models/attention_flax.py +9 -1
- diffusers/models/attention_processor.py +378 -6
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +17 -12
- diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +47 -23
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +402 -0
- diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +24 -28
- diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +51 -44
- diffusers/models/{vae.py → autoencoders/vae.py} +71 -17
- diffusers/models/controlnet.py +59 -39
- diffusers/models/controlnet_flax.py +19 -18
- diffusers/models/downsampling.py +338 -0
- diffusers/models/embeddings.py +112 -29
- diffusers/models/embeddings_flax.py +2 -0
- diffusers/models/lora.py +131 -1
- diffusers/models/modeling_flax_utils.py +14 -8
- diffusers/models/modeling_outputs.py +17 -0
- diffusers/models/modeling_utils.py +37 -29
- diffusers/models/normalization.py +110 -4
- diffusers/models/resnet.py +299 -652
- diffusers/models/transformer_2d.py +22 -5
- diffusers/models/transformer_temporal.py +183 -1
- diffusers/models/unet_2d_blocks_flax.py +5 -0
- diffusers/models/unet_2d_condition.py +46 -0
- diffusers/models/unet_2d_condition_flax.py +13 -13
- diffusers/models/unet_3d_blocks.py +957 -173
- diffusers/models/unet_3d_condition.py +16 -8
- diffusers/models/unet_kandinsky3.py +535 -0
- diffusers/models/unet_motion_model.py +48 -33
- diffusers/models/unet_spatio_temporal_condition.py +489 -0
- diffusers/models/upsampling.py +454 -0
- diffusers/models/uvit_2d.py +471 -0
- diffusers/models/vae_flax.py +7 -0
- diffusers/models/vq_model.py +12 -3
- diffusers/optimization.py +16 -9
- diffusers/pipelines/__init__.py +137 -76
- diffusers/pipelines/amused/__init__.py +62 -0
- diffusers/pipelines/amused/pipeline_amused.py +328 -0
- diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +66 -8
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
- diffusers/pipelines/auto_pipeline.py +23 -13
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +238 -35
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +148 -37
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +155 -41
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +123 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +216 -39
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +106 -34
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
- diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
- diffusers/pipelines/deprecated/__init__.py +153 -0
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +177 -34
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +182 -37
- diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
- diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
- diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
- diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
- diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
- diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
- diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
- diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
- diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
- diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +5 -4
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
- diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
- diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +8 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +34 -13
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +7 -6
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +12 -11
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +17 -11
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +11 -10
- diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +14 -13
- diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
- diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +83 -51
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +7 -6
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +7 -6
- diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +7 -6
- diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
- diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +1 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
- diffusers/pipelines/kandinsky3/__init__.py +49 -0
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +589 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +654 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +111 -11
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +102 -9
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
- diffusers/pipelines/onnx_utils.py +8 -5
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
- diffusers/pipelines/pipeline_flax_utils.py +11 -8
- diffusers/pipelines/pipeline_utils.py +63 -42
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +247 -38
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/__init__.py +37 -65
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +75 -78
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +174 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +178 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +224 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +74 -20
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +7 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +6 -2
- diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +3 -3
- diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +3 -2
- diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +4 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +7 -1
- diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +51 -7
- diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +57 -8
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
- diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
- diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +68 -10
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +194 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +205 -16
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +206 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +23 -17
- diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +652 -0
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +115 -14
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +6 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +23 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +334 -10
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +1331 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -1
- diffusers/schedulers/__init__.py +4 -4
- diffusers/schedulers/deprecated/__init__.py +50 -0
- diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
- diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
- diffusers/schedulers/scheduling_amused.py +162 -0
- diffusers/schedulers/scheduling_consistency_models.py +2 -0
- diffusers/schedulers/scheduling_ddim.py +1 -3
- diffusers/schedulers/scheduling_ddim_inverse.py +2 -7
- diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
- diffusers/schedulers/scheduling_ddpm.py +47 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +47 -3
- diffusers/schedulers/scheduling_deis_multistep.py +28 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +28 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +28 -6
- diffusers/schedulers/scheduling_dpmsolver_sde.py +3 -3
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +28 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +59 -3
- diffusers/schedulers/scheduling_euler_discrete.py +102 -16
- diffusers/schedulers/scheduling_heun_discrete.py +17 -5
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +17 -5
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +17 -5
- diffusers/schedulers/scheduling_lcm.py +123 -29
- diffusers/schedulers/scheduling_lms_discrete.py +3 -3
- diffusers/schedulers/scheduling_pndm.py +1 -3
- diffusers/schedulers/scheduling_repaint.py +1 -3
- diffusers/schedulers/scheduling_unipc_multistep.py +28 -6
- diffusers/schedulers/scheduling_utils.py +3 -1
- diffusers/schedulers/scheduling_utils_flax.py +3 -1
- diffusers/training_utils.py +1 -1
- diffusers/utils/__init__.py +1 -2
- diffusers/utils/constants.py +10 -12
- diffusers/utils/dummy_pt_objects.py +75 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
- diffusers/utils/dynamic_modules_utils.py +18 -22
- diffusers/utils/export_utils.py +8 -3
- diffusers/utils/hub_utils.py +24 -36
- diffusers/utils/logging.py +11 -11
- diffusers/utils/outputs.py +5 -5
- diffusers/utils/peft_utils.py +88 -44
- diffusers/utils/state_dict_utils.py +8 -0
- diffusers/utils/testing_utils.py +199 -1
- diffusers/utils/torch_utils.py +4 -4
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/METADATA +86 -69
- diffusers-0.25.0.dist-info/RECORD +360 -0
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
- diffusers/loaders.py +0 -3336
- diffusers-0.23.1.dist-info/RECORD +0 -323
- /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
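
The headline structural change in this release is visible in the list above: the monolithic diffusers/loaders.py (removed, -3336 lines) is split into a dedicated diffusers/loaders/ package (lora.py, ip_adapter.py, single_file.py, textual_inversion.py, unet.py), and a number of older pipelines move under diffusers/pipelines/deprecated/. A minimal sketch of how the split-out loader mixins surface on a pipeline in 0.25.0; the LoRA and textual-inversion repository ids below are illustrative placeholders, not files shipped in the wheel:

    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    # loaders/lora.py -> LoraLoaderMixin
    pipe.load_lora_weights("some-user/some-lora")               # placeholder repo id
    # loaders/textual_inversion.py -> TextualInversionLoaderMixin
    pipe.load_textual_inversion("sd-concepts-library/cat-toy")  # placeholder repo id
    # loaders/ip_adapter.py -> IPAdapterMixin
    pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")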
diffusers/utils/hub_utils.py
CHANGED
@@ -25,20 +25,21 @@ from typing import Dict, Optional, Union
 from uuid import uuid4
 
 from huggingface_hub import (
-    HfFolder,
     ModelCard,
     ModelCardData,
     create_repo,
+    get_full_repo_name,
     hf_hub_download,
     upload_folder,
-    whoami,
 )
+from huggingface_hub.constants import HF_HUB_CACHE, HF_HUB_DISABLE_TELEMETRY, HF_HUB_OFFLINE
 from huggingface_hub.file_download import REGEX_COMMIT_HASH
 from huggingface_hub.utils import (
     EntryNotFoundError,
     RepositoryNotFoundError,
     RevisionNotFoundError,
     is_jinja_available,
+    validate_hf_hub_args,
 )
 from packaging import version
 from requests import HTTPError
@@ -46,7 +47,6 @@ from requests import HTTPError
 from .. import __version__
 from .constants import (
     DEPRECATED_REVISION_ARGS,
-    DIFFUSERS_CACHE,
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
     SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_NAME,
@@ -69,9 +69,6 @@ logger = get_logger(__name__)
 
 MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md"
 SESSION_ID = uuid4().hex
-HF_HUB_OFFLINE = os.getenv("HF_HUB_OFFLINE", "").upper() in ENV_VARS_TRUE_VALUES
-DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES
-HUGGINGFACE_CO_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/"
 
 
 def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
@@ -79,7 +76,7 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
     Formats a user-agent string with basic info about a request.
     """
     ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
-    if DISABLE_TELEMETRY or HF_HUB_OFFLINE:
+    if HF_HUB_DISABLE_TELEMETRY or HF_HUB_OFFLINE:
         return ua + "; telemetry/off"
     if is_torch_available():
         ua += f"; torch/{_torch_version}"
@@ -98,16 +95,6 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
     return ua
 
 
-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
 def create_model_card(args, model_name):
     if not is_jinja_available():
         raise ValueError(
@@ -183,7 +170,7 @@ old_diffusers_cache = os.path.join(hf_cache_home, "diffusers")
 
 def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] = None) -> None:
     if new_cache_dir is None:
-        new_cache_dir = DIFFUSERS_CACHE
+        new_cache_dir = HF_HUB_CACHE
     if old_cache_dir is None:
         old_cache_dir = old_diffusers_cache
 
@@ -203,7 +190,7 @@ def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str]
     # At this point, old_cache_dir contains symlinks to the new cache (it can still be used).
 
 
-cache_version_file = os.path.join(DIFFUSERS_CACHE, "version_diffusers_cache.txt")
+cache_version_file = os.path.join(HF_HUB_CACHE, "version_diffusers_cache.txt")
 if not os.path.isfile(cache_version_file):
     cache_version = 0
 else:
@@ -233,12 +220,12 @@ if cache_version < 1:
 
 if cache_version < 1:
     try:
-        os.makedirs(DIFFUSERS_CACHE, exist_ok=True)
+        os.makedirs(HF_HUB_CACHE, exist_ok=True)
         with open(cache_version_file, "w") as f:
             f.write("1")
     except Exception:
         logger.warning(
-            f"There was a problem when trying to write in your cache folder ({DIFFUSERS_CACHE}). Please, ensure "
+            f"There was a problem when trying to write in your cache folder ({HF_HUB_CACHE}). Please, ensure "
            "the directory exists and can be written to."
        )
 
@@ -252,20 +239,21 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
     return weights_name
 
 
+@validate_hf_hub_args
 def _get_model_file(
-    pretrained_model_name_or_path,
+    pretrained_model_name_or_path: Union[str, Path],
     *,
-    weights_name,
-    subfolder,
-    cache_dir,
-    force_download,
-    proxies,
-    resume_download,
-    local_files_only,
-    use_auth_token,
-    user_agent,
-    revision,
-    commit_hash=None,
+    weights_name: str,
+    subfolder: Optional[str],
+    cache_dir: Optional[str],
+    force_download: bool,
+    proxies: Optional[Dict],
+    resume_download: bool,
+    local_files_only: bool,
+    token: Optional[str],
+    user_agent: Union[Dict, str, None],
+    revision: Optional[str],
+    commit_hash: Optional[str] = None,
 ):
     pretrained_model_name_or_path = str(pretrained_model_name_or_path)
     if os.path.isfile(pretrained_model_name_or_path):
@@ -300,7 +288,7 @@ def _get_model_file(
             proxies=proxies,
             resume_download=resume_download,
             local_files_only=local_files_only,
-            use_auth_token=use_auth_token,
+            token=token,
             user_agent=user_agent,
             subfolder=subfolder,
             revision=revision or commit_hash,
@@ -325,7 +313,7 @@ def _get_model_file(
             proxies=proxies,
             resume_download=resume_download,
             local_files_only=local_files_only,
-            use_auth_token=use_auth_token,
+            token=token,
             user_agent=user_agent,
             subfolder=subfolder,
             revision=revision or commit_hash,
@@ -336,7 +324,7 @@ def _get_model_file(
             raise EnvironmentError(
                 f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
                 "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
-                "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
+                "token having permission to this repo with `token` or log in with `huggingface-cli "
                 "login`."
             )
         except RevisionNotFoundError:
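
In short: hub_utils.py drops its home-grown telemetry and cache constants in favour of the equivalents shipped by huggingface_hub (HF_HUB_CACHE, HF_HUB_DISABLE_TELEMETRY, HF_HUB_OFFLINE), deletes the local get_full_repo_name in favour of the upstream import, and renames _get_model_file's use_auth_token parameter to token. The new @validate_hf_hub_args decorator keeps legacy callers working; a minimal sketch of that behaviour (the function fetch below is a stand-in for illustration, not part of diffusers):

    from huggingface_hub.utils import validate_hf_hub_args


    @validate_hf_hub_args
    def fetch(*, token=None):  # stand-in for a hub-facing helper like _get_model_file
        return token


    assert fetch(token="hf_xxx") == "hf_xxx"
    # The legacy keyword is remapped rather than rejected: fetch(use_auth_token="hf_xxx")
    # returns the same value and emits a deprecation warning instead of a TypeError.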
diffusers/utils/logging.py
CHANGED
@@ -28,7 +28,7 @@ from logging import (
     WARN,  # NOQA
     WARNING,  # NOQA
 )
-from typing import Optional
+from typing import Dict, Optional
 
 from tqdm import auto as tqdm_lib
 
@@ -49,7 +49,7 @@ _default_log_level = logging.WARNING
 _tqdm_active = True
 
 
-def _get_default_logging_level():
+def _get_default_logging_level() -> int:
     """
     If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
     not - fall back to `_default_log_level`
@@ -104,7 +104,7 @@ def _reset_library_root_logger() -> None:
     _default_handler = None
 
 
-def get_log_levels_dict():
+def get_log_levels_dict() -> Dict[str, int]:
     return log_levels
 
 
@@ -161,22 +161,22 @@ def set_verbosity(verbosity: int) -> None:
     _get_library_root_logger().setLevel(verbosity)
 
 
-def set_verbosity_info():
+def set_verbosity_info() -> None:
     """Set the verbosity to the `INFO` level."""
     return set_verbosity(INFO)
 
 
-def set_verbosity_warning():
+def set_verbosity_warning() -> None:
     """Set the verbosity to the `WARNING` level."""
     return set_verbosity(WARNING)
 
 
-def set_verbosity_debug():
+def set_verbosity_debug() -> None:
     """Set the verbosity to the `DEBUG` level."""
     return set_verbosity(DEBUG)
 
 
-def set_verbosity_error():
+def set_verbosity_error() -> None:
     """Set the verbosity to the `ERROR` level."""
     return set_verbosity(ERROR)
 
@@ -213,7 +213,7 @@ def remove_handler(handler: logging.Handler) -> None:
 
     _configure_library_root_logger()
 
-    assert handler is not None and handler not in _get_library_root_logger().handlers
+    assert handler is not None and handler in _get_library_root_logger().handlers
     _get_library_root_logger().removeHandler(handler)
 
 
@@ -263,7 +263,7 @@ def reset_format() -> None:
         handler.setFormatter(None)
 
 
-def warning_advice(self, *args, **kwargs):
+def warning_advice(self, *args, **kwargs) -> None:
     """
     This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
     warning will not be printed
@@ -327,13 +327,13 @@ def is_progress_bar_enabled() -> bool:
     return bool(_tqdm_active)
 
 
-def enable_progress_bar():
+def enable_progress_bar() -> None:
     """Enable tqdm progress bar."""
     global _tqdm_active
     _tqdm_active = True
 
 
-def disable_progress_bar():
+def disable_progress_bar() -> None:
     """Disable tqdm progress bar."""
     global _tqdm_active
     _tqdm_active = False
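
The logging.py changes are almost entirely annotations (return types on the verbosity and progress-bar helpers) plus the corrected membership assertion in remove_handler; the public API is unchanged. For reference, the usual way these helpers are driven:

    from diffusers.utils import logging

    logging.set_verbosity_info()            # now annotated `-> None`
    logger = logging.get_logger("diffusers")
    logger.info("verbosity: %s", logging.get_verbosity())
    levels = logging.get_log_levels_dict()  # now annotated `-> Dict[str, int]`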
diffusers/utils/outputs.py
CHANGED
@@ -24,7 +24,7 @@ import numpy as np
 from .import_utils import is_torch_available
 
 
-def is_tensor(x):
+def is_tensor(x) -> bool:
     """
     Tests if `x` is a `torch.Tensor` or `np.ndarray`.
     """
@@ -66,7 +66,7 @@ class BaseOutput(OrderedDict):
             lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
         )
 
-    def __post_init__(self):
+    def __post_init__(self) -> None:
         class_fields = fields(self)
 
         # Safety and consistency checks
@@ -97,14 +97,14 @@ class BaseOutput(OrderedDict):
     def update(self, *args, **kwargs):
         raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
 
-    def __getitem__(self, k):
+    def __getitem__(self, k: Any) -> Any:
         if isinstance(k, str):
             inner_dict = dict(self.items())
             return inner_dict[k]
         else:
             return self.to_tuple()[k]
 
-    def __setattr__(self, name, value):
+    def __setattr__(self, name: Any, value: Any) -> None:
         if name in self.keys() and value is not None:
             # Don't call self.__setitem__ to avoid recursion errors
             super().__setitem__(name, value)
@@ -123,7 +123,7 @@ class BaseOutput(OrderedDict):
         args = tuple(getattr(self, field.name) for field in fields(self))
         return callable, args, *remaining
 
-    def to_tuple(self) -> Tuple[Any]:
+    def to_tuple(self) -> Tuple[Any, ...]:
         """
         Convert self to a tuple containing all the attributes/keys that are not `None`.
         """
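
outputs.py likewise only gains annotations; note that to_tuple now correctly advertises a variable-length tuple (Tuple[Any, ...]). A small sketch of the dict-plus-dataclass behaviour those annotations describe, using a hypothetical subclass:

    from dataclasses import dataclass
    from typing import Optional

    import numpy as np

    from diffusers.utils import BaseOutput


    @dataclass
    class ExampleOutput(BaseOutput):  # hypothetical subclass, for illustration only
        images: np.ndarray
        extra: Optional[str] = None


    out = ExampleOutput(images=np.zeros((1, 8, 8, 3)))
    assert out["images"] is out.images  # str keys behave like a dict lookup
    assert out[0] is out.images         # other keys index into to_tuple()
    assert len(out.to_tuple()) == 1     # `extra` is None, so it is dropped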
diffusers/utils/peft_utils.py
CHANGED
@@ -23,55 +23,77 @@ from packaging import version
 from .import_utils import is_peft_available, is_torch_available
 
 
 if is_torch_available():
     import torch
 
 
 def recurse_remove_peft_layers(model):
     r"""
     Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`.
     """
-    from peft.tuners.lora import LoraLayer
-
-    for name, module in model.named_children():
-        if len(list(module.children())) > 0:
-            ## compound module, go inside it
-            recurse_remove_peft_layers(module)
-
-        module_replaced = False
-
-        if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
-            new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
-                module.weight.device
-            )
-            new_module.weight = module.weight
-            if module.bias is not None:
-                new_module.bias = module.bias
-
-            module_replaced = True
-        elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
-            new_module = torch.nn.Conv2d(
-                module.in_channels,
-                module.out_channels,
-                module.kernel_size,
-                module.stride,
-                module.padding,
-                module.dilation,
-                module.groups,
-            ).to(module.weight.device)
-
-            new_module.weight = module.weight
-            if module.bias is not None:
-                new_module.bias = module.bias
-
-            module_replaced = True
-
-        if module_replaced:
-            setattr(model, name, new_module)
-            del module
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+    from peft.tuners.tuners_utils import BaseTunerLayer
 
+    has_base_layer_pattern = False
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            has_base_layer_pattern = hasattr(module, "base_layer")
+            break
+
+    if has_base_layer_pattern:
+        from peft.utils import _get_submodules
+
+        key_list = [key for key, _ in model.named_modules() if "lora" not in key]
+        for key in key_list:
+            try:
+                parent, target, target_name = _get_submodules(model, key)
+            except AttributeError:
+                continue
+            if hasattr(target, "base_layer"):
+                setattr(parent, target_name, target.get_base_layer())
+    else:
+        # This is for backwards compatibility with PEFT <= 0.6.2.
+        # TODO can be removed once that PEFT version is no longer supported.
+        from peft.tuners.lora import LoraLayer
+
+        for name, module in model.named_children():
+            if len(list(module.children())) > 0:
+                ## compound module, go inside it
+                recurse_remove_peft_layers(module)
+
+            module_replaced = False
+
+            if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
+                new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
+                    module.weight.device
+                )
+                new_module.weight = module.weight
+                if module.bias is not None:
+                    new_module.bias = module.bias
+
+                module_replaced = True
+            elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
+                new_module = torch.nn.Conv2d(
+                    module.in_channels,
+                    module.out_channels,
+                    module.kernel_size,
+                    module.stride,
+                    module.padding,
+                    module.dilation,
+                    module.groups,
+                ).to(module.weight.device)
+
+                new_module.weight = module.weight
+                if module.bias is not None:
+                    new_module.bias = module.bias
+
+                module_replaced = True
+
+            if module_replaced:
+                setattr(model, name, new_module)
+                del module
+
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
     return model
 
 
@@ -79,6 +79,14 @@ PEFT_TO_DIFFUSERS = {
     ".v_proj.lora_A": ".v_proj.lora_linear_layer.down",
     ".out_proj.lora_B": ".out_proj.lora_linear_layer.up",
     ".out_proj.lora_A": ".out_proj.lora_linear_layer.down",
+    "to_k.lora_A": "to_k.lora.down",
+    "to_k.lora_B": "to_k.lora.up",
+    "to_q.lora_A": "to_q.lora.down",
+    "to_q.lora_B": "to_q.lora.up",
+    "to_v.lora_A": "to_v.lora.down",
+    "to_v.lora_B": "to_v.lora.up",
+    "to_out.0.lora_A": "to_out.0.lora.down",
+    "to_out.0.lora_B": "to_out.0.lora.up",
 }
 
 DIFFUSERS_OLD_TO_DIFFUSERS = {
@@ -180,6 +202,28 @@ def set_adapter_layers(model, enabled=True):
             module.disable_adapters = not enabled
 
 
+def delete_adapter_layers(model, adapter_name):
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            if hasattr(module, "delete_adapter"):
+                module.delete_adapter(adapter_name)
+            else:
+                raise ValueError(
+                    "The version of PEFT you are using is not compatible, please use a version that is greater than 0.6.1"
+                )
+
+    # For transformers integration - we need to pop the adapter from the config
+    if getattr(model, "_hf_peft_config_loaded", False) and hasattr(model, "peft_config"):
+        model.peft_config.pop(adapter_name, None)
+        # In case all adapters are deleted, we need to delete the config
+        # and make sure to set the flag to False
+        if len(model.peft_config) == 0:
+            del model.peft_config
+            model._hf_peft_config_loaded = None
+
+
 def set_weights_and_activate_adapters(model, adapter_names, weights):
     from peft.tuners.tuners_utils import BaseTunerLayer
 
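
delete_adapter_layers rounds out the existing enable/disable helpers and backs the adapter-management methods exposed on pipelines in this release. A hedged sketch of the user-facing flow (the adapter names and LoRA repo ids are placeholders):

    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe.load_lora_weights("some-user/pixel-lora", adapter_name="pixel")  # placeholder
    pipe.load_lora_weights("some-user/toy-lora", adapter_name="toy")      # placeholder

    pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.8, 0.5])
    pipe.delete_adapters("toy")  # removes the adapter modules and its config entry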
diffusers/utils/testing_utils.py
CHANGED
@@ -17,7 +17,7 @@ from contextlib import contextmanager
 from distutils.util import strtobool
 from io import BytesIO, StringIO
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union
 
 import numpy as np
 import PIL.Image
@@ -58,6 +58,17 @@ USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
 if is_torch_available():
     import torch
 
+    # Set a backend environment variable for any extra module import required for a custom accelerator
+    if "DIFFUSERS_TEST_BACKEND" in os.environ:
+        backend = os.environ["DIFFUSERS_TEST_BACKEND"]
+        try:
+            _ = importlib.import_module(backend)
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError(
+                f"Failed to import `DIFFUSERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module \
+                to enable a specified backend.):\n{e}"
+            ) from e
+
     if "DIFFUSERS_TEST_DEVICE" in os.environ:
         torch_device = os.environ["DIFFUSERS_TEST_DEVICE"]
         try:
@@ -210,6 +221,36 @@ def require_torch_gpu(test_case):
     )
 
 
+# These decorators are for accelerator-specific behaviours that are not GPU-specific
+def require_torch_accelerator(test_case):
+    """Decorator marking a test that requires an accelerator backend and PyTorch."""
+    return unittest.skipUnless(is_torch_available() and torch_device != "cpu", "test requires accelerator+PyTorch")(
+        test_case
+    )
+
+
+def require_torch_accelerator_with_fp16(test_case):
+    """Decorator marking a test that requires an accelerator with support for the FP16 data type."""
+    return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
+        test_case
+    )
+
+
+def require_torch_accelerator_with_fp64(test_case):
+    """Decorator marking a test that requires an accelerator with support for the FP64 data type."""
+    return unittest.skipUnless(_is_torch_fp64_available(torch_device), "test requires accelerator with fp64 support")(
+        test_case
+    )
+
+
+def require_torch_accelerator_with_training(test_case):
+    """Decorator marking a test that requires an accelerator with support for training."""
+    return unittest.skipUnless(
+        is_torch_available() and backend_supports_training(torch_device),
+        "test requires accelerator with training support",
+    )(test_case)
+
+
 def skip_mps(test_case):
     """Decorator marking a test to skip if torch_device is 'mps'"""
     return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case)
@@ -259,6 +300,23 @@ def require_peft_backend(test_case):
     return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case)
 
 
+def require_peft_version_greater(peft_version):
+    """
+    Decorator marking a test that requires PEFT backend with a specific version, this would require some specific
+    versions of PEFT and transformers.
+    """
+
+    def decorator(test_case):
+        correct_peft_version = is_peft_available() and version.parse(
+            version.parse(importlib.metadata.version("peft")).base_version
+        ) > version.parse(peft_version)
+        return unittest.skipUnless(
+            correct_peft_version, f"test requires PEFT backend with the version greater than {peft_version}"
+        )(test_case)
+
+    return decorator
+
+
 def deprecate_after_peft_backend(test_case):
     """
     Decorator marking a test that will be skipped after PEFT backend
@@ -766,3 +824,143 @@ def disable_full_determinism():
     os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
     os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
     torch.use_deterministic_algorithms(False)
+
+
+# Utils for custom and alternative accelerator devices
+def _is_torch_fp16_available(device):
+    if not is_torch_available():
+        return False
+
+    import torch
+
+    device = torch.device(device)
+
+    try:
+        x = torch.zeros((2, 2), dtype=torch.float16).to(device)
+        _ = torch.mul(x, x)
+        return True
+
+    except Exception as e:
+        if device.type == "cuda":
+            raise ValueError(
+                f"You have passed a device of type 'cuda' which should work with 'fp16', but 'cuda' does not seem to be correctly installed on your machine: {e}"
+            )
+
+        return False
+
+
+def _is_torch_fp64_available(device):
+    if not is_torch_available():
+        return False
+
+    import torch
+
+    try:
+        x = torch.zeros((2, 2), dtype=torch.float64).to(device)
+        _ = torch.mul(x, x)
+        return True
+
+    except Exception as e:
+        if device.type == "cuda":
+            raise ValueError(
+                f"You have passed a device of type 'cuda' which should work with 'fp64', but 'cuda' does not seem to be correctly installed on your machine: {e}"
+            )
+
+        return False
+
+
+# Guard these lookups for when Torch is not used - alternative accelerator support is for PyTorch
+if is_torch_available():
+    # Behaviour flags
+    BACKEND_SUPPORTS_TRAINING = {"cuda": True, "cpu": True, "mps": False, "default": True}
+
+    # Function definitions
+    BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "mps": None, "default": None}
+    BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "mps": lambda: 0, "default": 0}
+    BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}
+
+
+# This dispatches a defined function according to the accelerator from the function definitions.
+def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs):
+    if device not in dispatch_table:
+        return dispatch_table["default"](*args, **kwargs)
+
+    fn = dispatch_table[device]
+
+    # Some device agnostic functions return values. Need to guard against 'None' instead at
+    # user level
+    if fn is None:
+        return None
+
+    return fn(*args, **kwargs)
+
+
+# These are callables which automatically dispatch the function specific to the accelerator
+def backend_manual_seed(device: str, seed: int):
+    return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
+
+
+def backend_empty_cache(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
+
+
+def backend_device_count(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
+
+
+# These are callables which return boolean behaviour flags and can be used to specify some
+# device agnostic alternative where the feature is unsupported.
+def backend_supports_training(device: str):
+    if not is_torch_available():
+        return False
+
+    if device not in BACKEND_SUPPORTS_TRAINING:
+        device = "default"
+
+    return BACKEND_SUPPORTS_TRAINING[device]
+
+
+# Guard for when Torch is not available
+if is_torch_available():
+    # Update device function dict mapping
+    def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str):
+        try:
+            # Try to import the function directly
+            spec_fn = getattr(device_spec_module, attribute_name)
+            device_fn_dict[torch_device] = spec_fn
+        except AttributeError as e:
+            # If the function doesn't exist, and there is no default, throw an error
+            if "default" not in device_fn_dict:
+                raise AttributeError(
+                    f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found."
+                ) from e
+
+    if "DIFFUSERS_TEST_DEVICE_SPEC" in os.environ:
+        device_spec_path = os.environ["DIFFUSERS_TEST_DEVICE_SPEC"]
+        if not Path(device_spec_path).is_file():
+            raise ValueError(f"Specified path to device specification file is not found. Received {device_spec_path}")
+
+        try:
+            import_name = device_spec_path[: device_spec_path.index(".py")]
+        except ValueError as e:
+            raise ValueError(f"Provided device spec file is not a Python file! Received {device_spec_path}") from e
+
+        device_spec_module = importlib.import_module(import_name)
+
+        try:
+            device_name = device_spec_module.DEVICE_NAME
+        except AttributeError:
+            raise AttributeError("Device spec file did not contain `DEVICE_NAME`")
+
+        if "DIFFUSERS_TEST_DEVICE" in os.environ and torch_device != device_name:
+            msg = f"Mismatch between environment variable `DIFFUSERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n"
+            msg += "Either unset `DIFFUSERS_TEST_DEVICE` or ensure it matches device spec name."
+            raise ValueError(msg)
+
+        torch_device = device_name
+
+        # Add one entry here for each `BACKEND_*` dictionary.
+        update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
+        update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
+        update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
+        update_mapping_from_spec(BACKEND_SUPPORTS_TRAINING, "SUPPORTS_TRAINING")