diffusers 0.27.2__py3-none-any.whl → 0.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +26 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +33 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +8 -0
- diffusers/models/activations.py +23 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +475 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +363 -32
- diffusers/models/model_loading_utils.py +177 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_outputs.py +14 -0
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +175 -99
- diffusers/models/normalization.py +2 -1
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/__init__.py +3 -0
- diffusers/models/transformers/dit_transformer_2d.py +240 -0
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/hunyuan_transformer_2d.py +427 -0
- diffusers/models/transformers/pixart_transformer_2d.py +336 -0
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +292 -184
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +19 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +27 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +7 -4
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/hunyuandit/__init__.py +48 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +881 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +269 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +69 -79
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +20 -26
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +42 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
- diffusers/schedulers/scheduling_edm_euler.py +50 -31
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
- diffusers/schedulers/scheduling_euler_discrete.py +160 -68
- diffusers/schedulers/scheduling_heun_discrete.py +57 -39
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +24 -26
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +75 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/METADATA +7 -7
- diffusers-0.28.1.dist-info/RECORD +419 -0
- diffusers-0.27.2.dist-info/RECORD +0 -399
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/LICENSE +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/WHEEL +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/top_level.txt +0 -0
@@ -21,7 +21,7 @@ import re
|
|
21
21
|
import sys
|
22
22
|
from dataclasses import dataclass
|
23
23
|
from pathlib import Path
|
24
|
-
from typing import Any, Callable, Dict, List, Optional, Union
|
24
|
+
from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin
|
25
25
|
|
26
26
|
import numpy as np
|
27
27
|
import PIL.Image
|
@@ -43,7 +43,7 @@ from .. import __version__
|
|
43
43
|
from ..configuration_utils import ConfigMixin
|
44
44
|
from ..models import AutoencoderKL
|
45
45
|
from ..models.attention_processor import FusedAttnProcessor2_0
|
46
|
-
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
|
46
|
+
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
|
47
47
|
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
48
48
|
from ..utils import (
|
49
49
|
CONFIG_NAME,
|
@@ -72,6 +72,8 @@ from .pipeline_loading_utils import (
|
|
72
72
|
CUSTOM_PIPELINE_FILE_NAME,
|
73
73
|
LOADABLE_CLASSES,
|
74
74
|
_fetch_class_library_tuple,
|
75
|
+
_get_custom_pipeline_class,
|
76
|
+
_get_final_device_map,
|
75
77
|
_get_pipeline_class,
|
76
78
|
_unwrap_model,
|
77
79
|
is_safetensors_compatible,
|
@@ -90,6 +92,8 @@ LIBRARIES = []
|
|
90
92
|
for library in LOADABLE_CLASSES:
|
91
93
|
LIBRARIES.append(library)
|
92
94
|
|
95
|
+
SUPPORTED_DEVICE_MAP = ["balanced"]
|
96
|
+
|
93
97
|
logger = logging.get_logger(__name__)
|
94
98
|
|
95
99
|
|
@@ -140,6 +144,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
140
144
|
|
141
145
|
config_name = "model_index.json"
|
142
146
|
model_cpu_offload_seq = None
|
147
|
+
hf_device_map = None
|
143
148
|
_optional_components = []
|
144
149
|
_exclude_from_cpu_offload = []
|
145
150
|
_load_connected_pipes = False
|
@@ -371,8 +376,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
371
376
|
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
|
372
377
|
return False
|
373
378
|
|
374
|
-
return hasattr(module, "_hf_hook") and
|
375
|
-
module._hf_hook,
|
379
|
+
return hasattr(module, "_hf_hook") and (
|
380
|
+
isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
|
381
|
+
or hasattr(module._hf_hook, "hooks")
|
382
|
+
and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook)
|
376
383
|
)
|
377
384
|
|
378
385
|
def module_is_offloaded(module):
|
@@ -390,6 +397,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
390
397
|
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
|
391
398
|
)
|
392
399
|
|
400
|
+
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
|
401
|
+
if is_pipeline_device_mapped:
|
402
|
+
raise ValueError(
|
403
|
+
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`."
|
404
|
+
)
|
405
|
+
|
393
406
|
# Display a warning in this case (the operation succeeds but the benefits are lost)
|
394
407
|
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
|
395
408
|
if pipeline_is_offloaded and device and torch.device(device).type == "cuda":
|
@@ -520,9 +533,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
520
533
|
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
521
534
|
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
522
535
|
is not used.
|
523
|
-
resume_download
|
524
|
-
|
525
|
-
|
536
|
+
resume_download:
|
537
|
+
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
|
538
|
+
of Diffusers.
|
526
539
|
proxies (`Dict[str, str]`, *optional*):
|
527
540
|
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
528
541
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
@@ -539,7 +552,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
539
552
|
allowed by Git.
|
540
553
|
custom_revision (`str`, *optional*):
|
541
554
|
The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
|
542
|
-
`revision` when loading a custom pipeline from the Hub. Defaults to the latest stable 🤗 Diffusers
|
555
|
+
`revision` when loading a custom pipeline from the Hub. Defaults to the latest stable 🤗 Diffusers
|
556
|
+
version.
|
543
557
|
mirror (`str`, *optional*):
|
544
558
|
Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
|
545
559
|
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
@@ -611,7 +625,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
611
625
|
```
|
612
626
|
"""
|
613
627
|
cache_dir = kwargs.pop("cache_dir", None)
|
614
|
-
resume_download = kwargs.pop("resume_download",
|
628
|
+
resume_download = kwargs.pop("resume_download", None)
|
615
629
|
force_download = kwargs.pop("force_download", False)
|
616
630
|
proxies = kwargs.pop("proxies", None)
|
617
631
|
local_files_only = kwargs.pop("local_files_only", None)
|
@@ -642,18 +656,35 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
642
656
|
" install accelerate\n```\n."
|
643
657
|
)
|
644
658
|
|
659
|
+
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
660
|
+
raise NotImplementedError(
|
661
|
+
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
662
|
+
" `low_cpu_mem_usage=False`."
|
663
|
+
)
|
664
|
+
|
645
665
|
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
646
666
|
raise NotImplementedError(
|
647
667
|
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
648
668
|
" `device_map=None`."
|
649
669
|
)
|
650
670
|
|
651
|
-
if
|
671
|
+
if device_map is not None and not is_accelerate_available():
|
652
672
|
raise NotImplementedError(
|
653
|
-
"
|
654
|
-
" `low_cpu_mem_usage=False`."
|
673
|
+
"Using `device_map` requires the `accelerate` library. Please install it using: `pip install accelerate`."
|
655
674
|
)
|
656
675
|
|
676
|
+
if device_map is not None and not isinstance(device_map, str):
|
677
|
+
raise ValueError("`device_map` must be a string.")
|
678
|
+
|
679
|
+
if device_map is not None and device_map not in SUPPORTED_DEVICE_MAP:
|
680
|
+
raise NotImplementedError(
|
681
|
+
f"{device_map} not supported. Supported strategies are: {', '.join(SUPPORTED_DEVICE_MAP)}"
|
682
|
+
)
|
683
|
+
|
684
|
+
if device_map is not None and device_map in SUPPORTED_DEVICE_MAP:
|
685
|
+
if is_accelerate_version("<", "0.28.0"):
|
686
|
+
raise NotImplementedError("Device placement requires `accelerate` version `0.28.0` or later.")
|
687
|
+
|
657
688
|
if low_cpu_mem_usage is False and device_map is not None:
|
658
689
|
raise ValueError(
|
659
690
|
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
|
@@ -729,6 +760,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
729
760
|
revision=custom_revision,
|
730
761
|
)
|
731
762
|
|
763
|
+
if device_map is not None and pipeline_class._load_connected_pipes:
|
764
|
+
raise NotImplementedError("`device_map` is not yet supported for connected pipelines.")
|
765
|
+
|
732
766
|
# DEPRECATED: To be removed in 1.0.0
|
733
767
|
if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
|
734
768
|
version.parse(config_dict["_diffusers_version"]).base_version
|
@@ -795,17 +829,45 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
795
829
|
# import it here to avoid circular import
|
796
830
|
from diffusers import pipelines
|
797
831
|
|
798
|
-
# 6.
|
832
|
+
# 6. device map delegation
|
833
|
+
final_device_map = None
|
834
|
+
if device_map is not None:
|
835
|
+
final_device_map = _get_final_device_map(
|
836
|
+
device_map=device_map,
|
837
|
+
pipeline_class=pipeline_class,
|
838
|
+
passed_class_obj=passed_class_obj,
|
839
|
+
init_dict=init_dict,
|
840
|
+
library=library,
|
841
|
+
max_memory=max_memory,
|
842
|
+
torch_dtype=torch_dtype,
|
843
|
+
cached_folder=cached_folder,
|
844
|
+
force_download=force_download,
|
845
|
+
resume_download=resume_download,
|
846
|
+
proxies=proxies,
|
847
|
+
local_files_only=local_files_only,
|
848
|
+
token=token,
|
849
|
+
revision=revision,
|
850
|
+
)
|
851
|
+
|
852
|
+
# 7. Load each module in the pipeline
|
853
|
+
current_device_map = None
|
799
854
|
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
|
800
|
-
|
855
|
+
if final_device_map is not None and len(final_device_map) > 0:
|
856
|
+
component_device = final_device_map.get(name, None)
|
857
|
+
if component_device is not None:
|
858
|
+
current_device_map = {"": component_device}
|
859
|
+
else:
|
860
|
+
current_device_map = None
|
861
|
+
|
862
|
+
# 7.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
801
863
|
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
|
802
864
|
|
803
|
-
#
|
865
|
+
# 7.2 Define all importable classes
|
804
866
|
is_pipeline_module = hasattr(pipelines, library_name)
|
805
867
|
importable_classes = ALL_IMPORTABLE_CLASSES
|
806
868
|
loaded_sub_model = None
|
807
869
|
|
808
|
-
#
|
870
|
+
# 7.3 Use passed sub model or load class_name from library_name
|
809
871
|
if name in passed_class_obj:
|
810
872
|
# if the model is in a pipeline module, then we load it from the pipeline
|
811
873
|
# check that passed_class_obj has correct parent class
|
@@ -826,7 +888,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
826
888
|
torch_dtype=torch_dtype,
|
827
889
|
provider=provider,
|
828
890
|
sess_options=sess_options,
|
829
|
-
device_map=
|
891
|
+
device_map=current_device_map,
|
830
892
|
max_memory=max_memory,
|
831
893
|
offload_folder=offload_folder,
|
832
894
|
offload_state_dict=offload_state_dict,
|
@@ -893,7 +955,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
893
955
|
{"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
|
894
956
|
)
|
895
957
|
|
896
|
-
#
|
958
|
+
# 8. Potentially add passed objects if expected
|
897
959
|
missing_modules = set(expected_modules) - set(init_kwargs.keys())
|
898
960
|
passed_modules = list(passed_class_obj.keys())
|
899
961
|
optional_modules = pipeline_class._optional_components
|
@@ -906,11 +968,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
906
968
|
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
|
907
969
|
)
|
908
970
|
|
909
|
-
#
|
971
|
+
# 10. Instantiate the pipeline
|
910
972
|
model = pipeline_class(**init_kwargs)
|
911
973
|
|
912
|
-
#
|
974
|
+
# 11. Save where the model was instantiated from
|
913
975
|
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
976
|
+
if device_map is not None:
|
977
|
+
setattr(model, "hf_device_map", final_device_map)
|
914
978
|
return model
|
915
979
|
|
916
980
|
@property
|
@@ -939,6 +1003,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
939
1003
|
return torch.device(module._hf_hook.execution_device)
|
940
1004
|
return self.device
|
941
1005
|
|
1006
|
+
def remove_all_hooks(self):
|
1007
|
+
r"""
|
1008
|
+
Removes all hooks that were added when using `enable_sequential_cpu_offload` or `enable_model_cpu_offload`.
|
1009
|
+
"""
|
1010
|
+
for _, model in self.components.items():
|
1011
|
+
if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
|
1012
|
+
accelerate.hooks.remove_hook_from_module(model, recurse=True)
|
1013
|
+
self._all_hooks = []
|
1014
|
+
|
942
1015
|
def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
|
943
1016
|
r"""
|
944
1017
|
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
@@ -953,6 +1026,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
953
1026
|
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
|
954
1027
|
default to "cuda".
|
955
1028
|
"""
|
1029
|
+
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
|
1030
|
+
if is_pipeline_device_mapped:
|
1031
|
+
raise ValueError(
|
1032
|
+
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
|
1033
|
+
)
|
1034
|
+
|
956
1035
|
if self.model_cpu_offload_seq is None:
|
957
1036
|
raise ValueError(
|
958
1037
|
"Model CPU offload cannot be enabled because no `model_cpu_offload_seq` class attribute is set."
|
@@ -963,6 +1042,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
963
1042
|
else:
|
964
1043
|
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
965
1044
|
|
1045
|
+
self.remove_all_hooks()
|
1046
|
+
|
966
1047
|
torch_device = torch.device(device)
|
967
1048
|
device_index = torch_device.index
|
968
1049
|
|
@@ -979,11 +1060,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
979
1060
|
device = torch.device(f"{device_type}:{self._offload_gpu_id}")
|
980
1061
|
self._offload_device = device
|
981
1062
|
|
982
|
-
|
983
|
-
|
984
|
-
|
985
|
-
|
986
|
-
device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
1063
|
+
self.to("cpu", silence_dtype_warnings=True)
|
1064
|
+
device_mod = getattr(torch, device.type, None)
|
1065
|
+
if hasattr(device_mod, "empty_cache") and device_mod.is_available():
|
1066
|
+
device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
987
1067
|
|
988
1068
|
all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
|
989
1069
|
|
@@ -1021,11 +1101,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1021
1101
|
# `enable_model_cpu_offload` has not be called, so silently do nothing
|
1022
1102
|
return
|
1023
1103
|
|
1024
|
-
for hook in self._all_hooks:
|
1025
|
-
# offload model and remove hook from model
|
1026
|
-
hook.offload()
|
1027
|
-
hook.remove()
|
1028
|
-
|
1029
1104
|
# make sure the model is in the same state as before calling it
|
1030
1105
|
self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
|
1031
1106
|
|
@@ -1048,6 +1123,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1048
1123
|
from accelerate import cpu_offload
|
1049
1124
|
else:
|
1050
1125
|
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
|
1126
|
+
self.remove_all_hooks()
|
1127
|
+
|
1128
|
+
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
|
1129
|
+
if is_pipeline_device_mapped:
|
1130
|
+
raise ValueError(
|
1131
|
+
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
|
1132
|
+
)
|
1051
1133
|
|
1052
1134
|
torch_device = torch.device(device)
|
1053
1135
|
device_index = torch_device.index
|
@@ -1083,6 +1165,19 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1083
1165
|
offload_buffers = len(model._parameters) > 0
|
1084
1166
|
cpu_offload(model, device, offload_buffers=offload_buffers)
|
1085
1167
|
|
1168
|
+
def reset_device_map(self):
    r"""
    Resets the device maps (if any) to None.

    If `hf_device_map` is already `None` this does nothing. Otherwise it calls `self.remove_all_hooks()`, moves every
    `torch.nn.Module` component to `"cpu"`, and finally sets `hf_device_map` to `None`.
    """
    if self.hf_device_map is not None:
        self.remove_all_hooks()
        module_components = (c for c in self.components.values() if isinstance(c, torch.nn.Module))
        for module in module_components:
            module.to("cpu")
        self.hf_device_map = None
|
1180
|
+
|
1086
1181
|
@classmethod
|
1087
1182
|
@validate_hf_hub_args
|
1088
1183
|
def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
|
@@ -1121,9 +1216,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1121
1216
|
force_download (`bool`, *optional*, defaults to `False`):
|
1122
1217
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
1123
1218
|
cached versions if they exist.
|
1124
|
-
resume_download
|
1125
|
-
|
1126
|
-
|
1219
|
+
resume_download:
|
1220
|
+
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
|
1221
|
+
of Diffusers.
|
1127
1222
|
proxies (`Dict[str, str]`, *optional*):
|
1128
1223
|
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
1129
1224
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
@@ -1176,7 +1271,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1176
1271
|
|
1177
1272
|
"""
|
1178
1273
|
cache_dir = kwargs.pop("cache_dir", None)
|
1179
|
-
resume_download = kwargs.pop("resume_download",
|
1274
|
+
resume_download = kwargs.pop("resume_download", None)
|
1180
1275
|
force_download = kwargs.pop("force_download", False)
|
1181
1276
|
proxies = kwargs.pop("proxies", None)
|
1182
1277
|
local_files_only = kwargs.pop("local_files_only", None)
|
@@ -1382,7 +1477,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1382
1477
|
|
1383
1478
|
# Don't download index files of forbidden patterns either
|
1384
1479
|
ignore_patterns = ignore_patterns + [f"{i}.index.*json" for i in ignore_patterns]
|
1385
|
-
|
1386
1480
|
re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns]
|
1387
1481
|
re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns]
|
1388
1482
|
|
@@ -1472,6 +1566,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1472
1566
|
|
1473
1567
|
return expected_modules, optional_parameters
|
1474
1568
|
|
1569
|
+
@classmethod
def _get_signature_types(cls):
    r"""
    Map each parameter of `cls.__init__` to the tuple of types allowed by its annotation.

    - A plain class annotation `X` maps to `(X,)` (this includes un-annotated parameters, whose annotation is the
      class `inspect.Parameter.empty`).
    - A `typing.Union[...]` / `Optional[...]` annotation, or a PEP 604 `X | Y` union (Python 3.10+), maps to its
      member types as returned by `typing.get_args`.
    - Any other annotation (e.g. a string forward reference or a parametrized generic such as `List[int]`) is
      skipped and a warning is logged.

    Returns:
        `Dict[str, Tuple[type, ...]]`: parameter name -> accepted types.
    """
    import types

    # PEP 604 unions (`X | Y`) report `types.UnionType` as their origin rather than `typing.Union`; without this
    # they would be dropped and callers indexing the result by parameter name would raise KeyError.
    # `types.UnionType` only exists on Python 3.10+, so fall back to `Union` (a harmless duplicate) otherwise.
    union_origins = (Union, getattr(types, "UnionType", Union))

    signature_types = {}
    for k, v in inspect.signature(cls.__init__).parameters.items():
        if inspect.isclass(v.annotation):
            signature_types[k] = (v.annotation,)
        elif get_origin(v.annotation) in union_origins:
            signature_types[k] = get_args(v.annotation)
        else:
            logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.")
    return signature_types
|
1580
|
+
|
1475
1581
|
@property
|
1476
1582
|
def components(self) -> Dict[str, Any]:
|
1477
1583
|
r"""
|
@@ -1650,6 +1756,129 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1650
1756
|
for module in modules:
|
1651
1757
|
module.set_attention_slice(slice_size)
|
1652
1758
|
|
1759
|
+
@classmethod
def from_pipe(cls, pipeline, **kwargs):
    r"""
    Create a new pipeline from a given pipeline. This method is useful to create a new pipeline from the existing
    pipeline components without reallocating additional memory.

    Arguments:
        pipeline (`DiffusionPipeline`):
            The pipeline from which to create a new pipeline.
        kwargs (remaining dictionary of keyword arguments, *optional*):
            - `torch_dtype`: if given, the new pipeline is cast with `.to(dtype=torch_dtype)`.
            - `custom_pipeline` / `custom_revision`: resolve the class to instantiate with
              `_get_custom_pipeline_class` instead of using `cls`.
            - any component name expected by the new pipeline: overrides the component taken from `pipeline`.
            - any optional `__init__` argument of the new pipeline: overrides the value stored in `pipeline.config`.
            Any remaining keyword arguments are forwarded to the new pipeline's `__init__`.

    Returns:
        `DiffusionPipeline`:
            A new pipeline with the same weights and configurations as `pipeline`.

    Raises:
        ValueError: if an expected component of the new pipeline is neither passed in `kwargs` nor carried over from
            `pipeline`, has no default value in the new pipeline's `__init__`, and is not listed in
            `pipeline._optional_components`.

    Model components (`ModelMixin` instances) whose type does not match the new pipeline's type hints are not
    carried over; a warning is logged instead.

    Examples:

    ```py
    >>> from diffusers import StableDiffusionPipeline, StableDiffusionSAGPipeline

    >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    >>> new_pipe = StableDiffusionSAGPipeline.from_pipe(pipe)
    ```
    """

    original_config = dict(pipeline.config)
    torch_dtype = kwargs.pop("torch_dtype", None)

    # derive the pipeline class to instantiate
    custom_pipeline = kwargs.pop("custom_pipeline", None)
    custom_revision = kwargs.pop("custom_revision", None)

    if custom_pipeline is not None:
        pipeline_class = _get_custom_pipeline_class(custom_pipeline, revision=custom_revision)
    else:
        pipeline_class = cls

    expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
    # true_optional_modules are optional components with default value in signature so it is ok not to pass them to
    # `__init__`, e.g. `image_encoder` for StableDiffusionPipeline.
    # The signature must come from `pipeline_class`, not `cls`: with `custom_pipeline` the two differ, and inspecting
    # `cls` would wrongly report defaulted components of the custom class as missing.
    # `is not` (rather than `!=`) avoids element-wise comparison when a default is array-like.
    parameters = inspect.signature(pipeline_class.__init__).parameters
    true_optional_modules = {
        k for k, v in parameters.items() if v.default is not inspect.Parameter.empty and k in expected_modules
    }

    # get the class of each component based on its type hint
    # e.g. {"unet": (UNet2DConditionModel,), "text_encoder": (CLIPTextModel,)}
    component_types = pipeline_class._get_signature_types()

    pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
    # allow users pass modules in `kwargs` to override the original pipeline's components
    passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}

    original_class_obj = {}
    for name, component in pipeline.components.items():
        if name not in expected_modules or name in passed_class_obj:
            continue
        # For model components, do not switch over if the class does not match the type hint in the new pipeline's
        # signature. A component whose type hint could not be resolved (absent from `component_types`) cannot be
        # checked, so it is carried over as-is instead of raising a KeyError.
        expected_types = component_types.get(name)
        if not isinstance(component, ModelMixin) or expected_types is None or type(component) in expected_types:
            original_class_obj[name] = component
        else:
            logger.warning(
                f"component {name} is not switched over to new pipeline because type does not match the expected."
                f" {name} is {type(component)} while the new pipeline expect {component_types[name]}."
                f" please pass the component of the correct type to the new pipeline. `from_pipe(..., {name}={name})`"
            )

    # allow users pass optional kwargs to override the original pipelines config attribute
    passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
    original_pipe_kwargs = {
        k: original_config[k]
        for k in original_config.keys()
        if k in optional_kwargs and k not in passed_pipe_kwargs
    }

    # config attributes that were not expected by a pipeline are stored as private attributes (`_<name>`), e.g. when
    # the original pipeline was itself created with `from_pipe` from another pipeline that had this config.
    # If the new pipeline accepts them as optional arguments, pass them through under their public name.
    additional_pipe_kwargs = [
        k[1:]
        for k in original_config.keys()
        if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs
    ]
    for k in additional_pipe_kwargs:
        original_pipe_kwargs[k] = original_config.pop(f"_{k}")

    pipeline_kwargs = {
        **passed_class_obj,
        **original_class_obj,
        **passed_pipe_kwargs,
        **original_pipe_kwargs,
        **kwargs,
    }

    # store unused config as private attribute in the new pipeline
    unused_original_config = {
        f"{'' if k.startswith('_') else '_'}{k}": v for k, v in original_config.items() if k not in pipeline_kwargs
    }

    # NOTE(review): this subtracts the *source* pipeline's `_optional_components`; confirm whether the new
    # `pipeline_class._optional_components` was intended instead.
    missing_modules = (
        set(expected_modules)
        - set(pipeline._optional_components)
        - set(pipeline_kwargs.keys())
        - set(true_optional_modules)
    )

    if len(missing_modules) > 0:
        raise ValueError(
            f"Pipeline {pipeline_class} expected {expected_modules}, but only {set(list(passed_class_obj.keys()) + list(original_class_obj.keys()))} were passed"
        )

    new_pipeline = pipeline_class(**pipeline_kwargs)
    if pretrained_model_name_or_path is not None:
        new_pipeline.register_to_config(_name_or_path=pretrained_model_name_or_path)
    new_pipeline.register_to_config(**unused_original_config)

    if torch_dtype is not None:
        new_pipeline.to(dtype=torch_dtype)

    return new_pipeline
|
1881
|
+
|
1653
1882
|
|
1654
1883
|
class StableDiffusionMixin:
|
1655
1884
|
r"""
|
@@ -1713,8 +1942,8 @@ class StableDiffusionMixin:
|
|
1713
1942
|
|
1714
1943
|
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
|
1715
1944
|
"""
|
1716
|
-
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
1717
|
-
|
1945
|
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
1946
|
+
are fused. For cross-attention modules, key and value projection matrices are fused.
|
1718
1947
|
|
1719
1948
|
<Tip warning={true}>
|
1720
1949
|
|
@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
|
|
23
23
|
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
24
24
|
else:
|
25
25
|
_import_structure["pipeline_pixart_alpha"] = ["PixArtAlphaPipeline"]
|
26
|
+
_import_structure["pipeline_pixart_sigma"] = ["PixArtSigmaPipeline"]
|
26
27
|
|
27
28
|
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
28
29
|
try:
|
@@ -32,7 +33,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|
32
33
|
except OptionalDependencyNotAvailable:
|
33
34
|
from ...utils.dummy_torch_and_transformers_objects import *
|
34
35
|
else:
|
35
|
-
from .pipeline_pixart_alpha import
|
36
|
+
from .pipeline_pixart_alpha import (
|
37
|
+
ASPECT_RATIO_256_BIN,
|
38
|
+
ASPECT_RATIO_512_BIN,
|
39
|
+
ASPECT_RATIO_1024_BIN,
|
40
|
+
PixArtAlphaPipeline,
|
41
|
+
)
|
42
|
+
from .pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN, PixArtSigmaPipeline
|
36
43
|
|
37
44
|
else:
|
38
45
|
import sys
|