diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +97 -4
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +13 -1
- diffusers/image_processor.py +282 -71
- diffusers/loaders/__init__.py +24 -3
- diffusers/loaders/ip_adapter.py +543 -16
- diffusers/loaders/lora_base.py +138 -125
- diffusers/loaders/lora_conversion_utils.py +647 -0
- diffusers/loaders/lora_pipeline.py +2216 -230
- diffusers/loaders/peft.py +380 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +597 -10
- diffusers/loaders/textual_inversion.py +5 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +56 -12
- diffusers/models/__init__.py +49 -12
- diffusers/models/activations.py +22 -9
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +98 -13
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2160 -346
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +73 -12
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +26 -376
- diffusers/models/controlnet_sparsectrl.py +46 -719
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +996 -92
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +264 -14
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +334 -51
- diffusers/models/normalization.py +157 -13
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +10 -2
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +189 -51
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +112 -18
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +9 -9
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +46 -68
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +14 -6
- diffusers/pipelines/__init__.py +69 -6
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
- diffusers/pipelines/auto_pipeline.py +88 -10
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
- diffusers/pipelines/flux/__init__.py +23 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +256 -48
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +13 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +250 -31
- diffusers/pipelines/pipeline_utils.py +158 -186
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +6 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
- diffusers/schedulers/scheduling_deis_multistep.py +102 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +102 -6
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
- diffusers/training_utils.py +63 -19
- diffusers/utils/__init__.py +7 -1
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +240 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +44 -40
- diffusers/utils/import_utils.py +98 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +6 -3
- diffusers/utils/testing_utils.py +115 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,21 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from typing import List, Union
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
import PIL.Image
|
6
|
+
|
7
|
+
from ...utils import BaseOutput
|
8
|
+
|
9
|
+
|
10
|
+
@dataclass
|
11
|
+
class CogView3PipelineOutput(BaseOutput):
|
12
|
+
"""
|
13
|
+
Output class for CogView3 pipelines.
|
14
|
+
|
15
|
+
Args:
|
16
|
+
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
17
|
+
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
18
|
+
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
19
|
+
"""
|
20
|
+
|
21
|
+
images: Union[List[PIL.Image.Image], np.ndarray]
|
@@ -1,80 +1,86 @@
|
|
1
|
-
from typing import TYPE_CHECKING
|
2
|
-
|
3
|
-
from ...utils import (
|
4
|
-
DIFFUSERS_SLOW_IMPORT,
|
5
|
-
OptionalDependencyNotAvailable,
|
6
|
-
_LazyModule,
|
7
|
-
get_objects_from_module,
|
8
|
-
is_flax_available,
|
9
|
-
is_torch_available,
|
10
|
-
is_transformers_available,
|
11
|
-
)
|
12
|
-
|
13
|
-
|
14
|
-
_dummy_objects = {}
|
15
|
-
_import_structure = {}
|
16
|
-
|
17
|
-
try:
|
18
|
-
if not (is_transformers_available() and is_torch_available()):
|
19
|
-
raise OptionalDependencyNotAvailable()
|
20
|
-
except OptionalDependencyNotAvailable:
|
21
|
-
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
22
|
-
|
23
|
-
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
24
|
-
else:
|
25
|
-
_import_structure["multicontrolnet"] = ["MultiControlNetModel"]
|
26
|
-
_import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
|
27
|
-
_import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"]
|
28
|
-
_import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
|
29
|
-
_import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
|
30
|
-
_import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
|
31
|
-
_import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
|
32
|
-
_import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
from .
|
54
|
-
|
55
|
-
from .
|
56
|
-
from .
|
57
|
-
from .
|
58
|
-
from .
|
59
|
-
from .
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
from
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
1
|
+
from typing import TYPE_CHECKING
|
2
|
+
|
3
|
+
from ...utils import (
|
4
|
+
DIFFUSERS_SLOW_IMPORT,
|
5
|
+
OptionalDependencyNotAvailable,
|
6
|
+
_LazyModule,
|
7
|
+
get_objects_from_module,
|
8
|
+
is_flax_available,
|
9
|
+
is_torch_available,
|
10
|
+
is_transformers_available,
|
11
|
+
)
|
12
|
+
|
13
|
+
|
14
|
+
_dummy_objects = {}
|
15
|
+
_import_structure = {}
|
16
|
+
|
17
|
+
try:
|
18
|
+
if not (is_transformers_available() and is_torch_available()):
|
19
|
+
raise OptionalDependencyNotAvailable()
|
20
|
+
except OptionalDependencyNotAvailable:
|
21
|
+
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
22
|
+
|
23
|
+
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
24
|
+
else:
|
25
|
+
_import_structure["multicontrolnet"] = ["MultiControlNetModel"]
|
26
|
+
_import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
|
27
|
+
_import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"]
|
28
|
+
_import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
|
29
|
+
_import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
|
30
|
+
_import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
|
31
|
+
_import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
|
32
|
+
_import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
|
33
|
+
_import_structure["pipeline_controlnet_union_inpaint_sd_xl"] = ["StableDiffusionXLControlNetUnionInpaintPipeline"]
|
34
|
+
_import_structure["pipeline_controlnet_union_sd_xl"] = ["StableDiffusionXLControlNetUnionPipeline"]
|
35
|
+
_import_structure["pipeline_controlnet_union_sd_xl_img2img"] = ["StableDiffusionXLControlNetUnionImg2ImgPipeline"]
|
36
|
+
try:
|
37
|
+
if not (is_transformers_available() and is_flax_available()):
|
38
|
+
raise OptionalDependencyNotAvailable()
|
39
|
+
except OptionalDependencyNotAvailable:
|
40
|
+
from ...utils import dummy_flax_and_transformers_objects # noqa F403
|
41
|
+
|
42
|
+
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
|
43
|
+
else:
|
44
|
+
_import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
|
45
|
+
|
46
|
+
|
47
|
+
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
48
|
+
try:
|
49
|
+
if not (is_transformers_available() and is_torch_available()):
|
50
|
+
raise OptionalDependencyNotAvailable()
|
51
|
+
|
52
|
+
except OptionalDependencyNotAvailable:
|
53
|
+
from ...utils.dummy_torch_and_transformers_objects import *
|
54
|
+
else:
|
55
|
+
from .multicontrolnet import MultiControlNetModel
|
56
|
+
from .pipeline_controlnet import StableDiffusionControlNetPipeline
|
57
|
+
from .pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline
|
58
|
+
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
|
59
|
+
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
|
60
|
+
from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
|
61
|
+
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
|
62
|
+
from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
|
63
|
+
from .pipeline_controlnet_union_inpaint_sd_xl import StableDiffusionXLControlNetUnionInpaintPipeline
|
64
|
+
from .pipeline_controlnet_union_sd_xl import StableDiffusionXLControlNetUnionPipeline
|
65
|
+
from .pipeline_controlnet_union_sd_xl_img2img import StableDiffusionXLControlNetUnionImg2ImgPipeline
|
66
|
+
|
67
|
+
try:
|
68
|
+
if not (is_transformers_available() and is_flax_available()):
|
69
|
+
raise OptionalDependencyNotAvailable()
|
70
|
+
except OptionalDependencyNotAvailable:
|
71
|
+
from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
|
72
|
+
else:
|
73
|
+
from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
|
74
|
+
|
75
|
+
|
76
|
+
else:
|
77
|
+
import sys
|
78
|
+
|
79
|
+
sys.modules[__name__] = _LazyModule(
|
80
|
+
__name__,
|
81
|
+
globals()["__file__"],
|
82
|
+
_import_structure,
|
83
|
+
module_spec=__spec__,
|
84
|
+
)
|
85
|
+
for name, value in _dummy_objects.items():
|
86
|
+
setattr(sys.modules[__name__], name, value)
|
@@ -1,183 +1,12 @@
|
|
1
|
-
import
|
2
|
-
from
|
3
|
-
|
4
|
-
import torch
|
5
|
-
from torch import nn
|
6
|
-
|
7
|
-
from ...models.controlnet import ControlNetModel, ControlNetOutput
|
8
|
-
from ...models.modeling_utils import ModelMixin
|
9
|
-
from ...utils import logging
|
1
|
+
from ...models.controlnets.multicontrolnet import MultiControlNetModel
|
2
|
+
from ...utils import deprecate, logging
|
10
3
|
|
11
4
|
|
12
5
|
logger = logging.get_logger(__name__)
|
13
6
|
|
14
7
|
|
15
|
-
class MultiControlNetModel(
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
compatible with `ControlNetModel`.
|
21
|
-
|
22
|
-
Args:
|
23
|
-
controlnets (`List[ControlNetModel]`):
|
24
|
-
Provides additional conditioning to the unet during the denoising process. You must set multiple
|
25
|
-
`ControlNetModel` as a list.
|
26
|
-
"""
|
27
|
-
|
28
|
-
def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
|
29
|
-
super().__init__()
|
30
|
-
self.nets = nn.ModuleList(controlnets)
|
31
|
-
|
32
|
-
def forward(
|
33
|
-
self,
|
34
|
-
sample: torch.Tensor,
|
35
|
-
timestep: Union[torch.Tensor, float, int],
|
36
|
-
encoder_hidden_states: torch.Tensor,
|
37
|
-
controlnet_cond: List[torch.tensor],
|
38
|
-
conditioning_scale: List[float],
|
39
|
-
class_labels: Optional[torch.Tensor] = None,
|
40
|
-
timestep_cond: Optional[torch.Tensor] = None,
|
41
|
-
attention_mask: Optional[torch.Tensor] = None,
|
42
|
-
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
43
|
-
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
44
|
-
guess_mode: bool = False,
|
45
|
-
return_dict: bool = True,
|
46
|
-
) -> Union[ControlNetOutput, Tuple]:
|
47
|
-
for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
|
48
|
-
down_samples, mid_sample = controlnet(
|
49
|
-
sample=sample,
|
50
|
-
timestep=timestep,
|
51
|
-
encoder_hidden_states=encoder_hidden_states,
|
52
|
-
controlnet_cond=image,
|
53
|
-
conditioning_scale=scale,
|
54
|
-
class_labels=class_labels,
|
55
|
-
timestep_cond=timestep_cond,
|
56
|
-
attention_mask=attention_mask,
|
57
|
-
added_cond_kwargs=added_cond_kwargs,
|
58
|
-
cross_attention_kwargs=cross_attention_kwargs,
|
59
|
-
guess_mode=guess_mode,
|
60
|
-
return_dict=return_dict,
|
61
|
-
)
|
62
|
-
|
63
|
-
# merge samples
|
64
|
-
if i == 0:
|
65
|
-
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
|
66
|
-
else:
|
67
|
-
down_block_res_samples = [
|
68
|
-
samples_prev + samples_curr
|
69
|
-
for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
|
70
|
-
]
|
71
|
-
mid_block_res_sample += mid_sample
|
72
|
-
|
73
|
-
return down_block_res_samples, mid_block_res_sample
|
74
|
-
|
75
|
-
def save_pretrained(
|
76
|
-
self,
|
77
|
-
save_directory: Union[str, os.PathLike],
|
78
|
-
is_main_process: bool = True,
|
79
|
-
save_function: Callable = None,
|
80
|
-
safe_serialization: bool = True,
|
81
|
-
variant: Optional[str] = None,
|
82
|
-
):
|
83
|
-
"""
|
84
|
-
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
85
|
-
`[`~pipelines.controlnet.MultiControlNetModel.from_pretrained`]` class method.
|
86
|
-
|
87
|
-
Arguments:
|
88
|
-
save_directory (`str` or `os.PathLike`):
|
89
|
-
Directory to which to save. Will be created if it doesn't exist.
|
90
|
-
is_main_process (`bool`, *optional*, defaults to `True`):
|
91
|
-
Whether the process calling this is the main process or not. Useful when in distributed training like
|
92
|
-
TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
|
93
|
-
the main process to avoid race conditions.
|
94
|
-
save_function (`Callable`):
|
95
|
-
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
|
96
|
-
need to replace `torch.save` by another method. Can be configured with the environment variable
|
97
|
-
`DIFFUSERS_SAVE_MODE`.
|
98
|
-
safe_serialization (`bool`, *optional*, defaults to `True`):
|
99
|
-
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
100
|
-
variant (`str`, *optional*):
|
101
|
-
If specified, weights are saved in the format pytorch_model.<variant>.bin.
|
102
|
-
"""
|
103
|
-
for idx, controlnet in enumerate(self.nets):
|
104
|
-
suffix = "" if idx == 0 else f"_{idx}"
|
105
|
-
controlnet.save_pretrained(
|
106
|
-
save_directory + suffix,
|
107
|
-
is_main_process=is_main_process,
|
108
|
-
save_function=save_function,
|
109
|
-
safe_serialization=safe_serialization,
|
110
|
-
variant=variant,
|
111
|
-
)
|
112
|
-
|
113
|
-
@classmethod
|
114
|
-
def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
|
115
|
-
r"""
|
116
|
-
Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models.
|
117
|
-
|
118
|
-
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
|
119
|
-
the model, you should first set it back in training mode with `model.train()`.
|
120
|
-
|
121
|
-
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
|
122
|
-
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
|
123
|
-
task.
|
124
|
-
|
125
|
-
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
|
126
|
-
weights are discarded.
|
127
|
-
|
128
|
-
Parameters:
|
129
|
-
pretrained_model_path (`os.PathLike`):
|
130
|
-
A path to a *directory* containing model weights saved using
|
131
|
-
[`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g.,
|
132
|
-
`./my_model_directory/controlnet`.
|
133
|
-
torch_dtype (`str` or `torch.dtype`, *optional*):
|
134
|
-
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
|
135
|
-
will be automatically derived from the model's weights.
|
136
|
-
output_loading_info(`bool`, *optional*, defaults to `False`):
|
137
|
-
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
138
|
-
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
139
|
-
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
140
|
-
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
141
|
-
same device.
|
142
|
-
|
143
|
-
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
|
144
|
-
more information about each option see [designing a device
|
145
|
-
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
146
|
-
max_memory (`Dict`, *optional*):
|
147
|
-
A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
|
148
|
-
GPU and the available CPU RAM if unset.
|
149
|
-
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
150
|
-
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
|
151
|
-
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
152
|
-
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
153
|
-
setting this argument to `True` will raise an error.
|
154
|
-
variant (`str`, *optional*):
|
155
|
-
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
|
156
|
-
ignored when using `from_flax`.
|
157
|
-
use_safetensors (`bool`, *optional*, defaults to `None`):
|
158
|
-
If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
|
159
|
-
`safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
|
160
|
-
`safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
|
161
|
-
"""
|
162
|
-
idx = 0
|
163
|
-
controlnets = []
|
164
|
-
|
165
|
-
# load controlnet and append to list until no controlnet directory exists anymore
|
166
|
-
# first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_prertained`
|
167
|
-
# second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ...
|
168
|
-
model_path_to_load = pretrained_model_path
|
169
|
-
while os.path.isdir(model_path_to_load):
|
170
|
-
controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs)
|
171
|
-
controlnets.append(controlnet)
|
172
|
-
|
173
|
-
idx += 1
|
174
|
-
model_path_to_load = pretrained_model_path + f"_{idx}"
|
175
|
-
|
176
|
-
logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.")
|
177
|
-
|
178
|
-
if len(controlnets) == 0:
|
179
|
-
raise ValueError(
|
180
|
-
f"No ControlNets found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
|
181
|
-
)
|
182
|
-
|
183
|
-
return cls(controlnets)
|
8
|
+
class MultiControlNetModel(MultiControlNetModel):
|
9
|
+
def __init__(self, *args, **kwargs):
|
10
|
+
deprecation_message = "Importing `MultiControlNetModel` from `diffusers.pipelines.controlnet.multicontrolnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel`, instead."
|
11
|
+
deprecate("diffusers.pipelines.controlnet.multicontrolnet.MultiControlNetModel", "0.34", deprecation_message)
|
12
|
+
super().__init__(*args, **kwargs)
|
@@ -25,12 +25,13 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
|
|
25
25
|
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
26
26
|
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
27
27
|
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
28
|
-
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
|
28
|
+
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
|
29
29
|
from ...models.lora import adjust_lora_scale_text_encoder
|
30
30
|
from ...schedulers import KarrasDiffusionSchedulers
|
31
31
|
from ...utils import (
|
32
32
|
USE_PEFT_BACKEND,
|
33
33
|
deprecate,
|
34
|
+
is_torch_xla_available,
|
34
35
|
logging,
|
35
36
|
replace_example_docstring,
|
36
37
|
scale_lora_layers,
|
@@ -40,9 +41,15 @@ from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_ten
|
|
40
41
|
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
41
42
|
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
|
42
43
|
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
43
|
-
from .multicontrolnet import MultiControlNetModel
|
44
44
|
|
45
45
|
|
46
|
+
if is_torch_xla_available():
|
47
|
+
import torch_xla.core.xla_model as xm
|
48
|
+
|
49
|
+
XLA_AVAILABLE = True
|
50
|
+
else:
|
51
|
+
XLA_AVAILABLE = False
|
52
|
+
|
46
53
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
47
54
|
|
48
55
|
|
@@ -101,7 +108,7 @@ def retrieve_timesteps(
|
|
101
108
|
sigmas: Optional[List[float]] = None,
|
102
109
|
**kwargs,
|
103
110
|
):
|
104
|
-
"""
|
111
|
+
r"""
|
105
112
|
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
106
113
|
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
107
114
|
|
@@ -893,6 +900,10 @@ class StableDiffusionControlNetPipeline(
|
|
893
900
|
def num_timesteps(self):
|
894
901
|
return self._num_timesteps
|
895
902
|
|
903
|
+
@property
|
904
|
+
def interrupt(self):
|
905
|
+
return self._interrupt
|
906
|
+
|
896
907
|
@torch.no_grad()
|
897
908
|
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
898
909
|
def __call__(
|
@@ -1089,6 +1100,7 @@ class StableDiffusionControlNetPipeline(
|
|
1089
1100
|
self._guidance_scale = guidance_scale
|
1090
1101
|
self._clip_skip = clip_skip
|
1091
1102
|
self._cross_attention_kwargs = cross_attention_kwargs
|
1103
|
+
self._interrupt = False
|
1092
1104
|
|
1093
1105
|
# 2. Define call parameters
|
1094
1106
|
if prompt is not None and isinstance(prompt, str):
|
@@ -1235,6 +1247,9 @@ class StableDiffusionControlNetPipeline(
|
|
1235
1247
|
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
|
1236
1248
|
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1237
1249
|
for i, t in enumerate(timesteps):
|
1250
|
+
if self.interrupt:
|
1251
|
+
continue
|
1252
|
+
|
1238
1253
|
# Relevant thread:
|
1239
1254
|
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
|
1240
1255
|
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
|
@@ -1316,6 +1331,8 @@ class StableDiffusionControlNetPipeline(
|
|
1316
1331
|
step_idx = i // getattr(self.scheduler, "order", 1)
|
1317
1332
|
callback(step_idx, t, latents)
|
1318
1333
|
|
1334
|
+
if XLA_AVAILABLE:
|
1335
|
+
xm.mark_step()
|
1319
1336
|
# If we do sequential model offloading, let's offload unet and controlnet
|
1320
1337
|
# manually for max memory savings
|
1321
1338
|
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
|
|
24
24
|
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
25
25
|
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
26
26
|
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
27
|
-
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
|
27
|
+
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
|
28
28
|
from ...models.lora import adjust_lora_scale_text_encoder
|
29
29
|
from ...schedulers import KarrasDiffusionSchedulers
|
30
30
|
from ...utils import (
|
@@ -39,7 +39,6 @@ from ...utils.torch_utils import is_compiled_module, randn_tensor
|
|
39
39
|
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
40
40
|
from ..stable_diffusion import StableDiffusionPipelineOutput
|
41
41
|
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
42
|
-
from .multicontrolnet import MultiControlNetModel
|
43
42
|
|
44
43
|
|
45
44
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
@@ -891,6 +890,10 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
|
891
890
|
def num_timesteps(self):
|
892
891
|
return self._num_timesteps
|
893
892
|
|
893
|
+
@property
|
894
|
+
def interrupt(self):
|
895
|
+
return self._interrupt
|
896
|
+
|
894
897
|
@torch.no_grad()
|
895
898
|
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
896
899
|
def __call__(
|
@@ -1081,6 +1084,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
|
1081
1084
|
self._guidance_scale = guidance_scale
|
1082
1085
|
self._clip_skip = clip_skip
|
1083
1086
|
self._cross_attention_kwargs = cross_attention_kwargs
|
1087
|
+
self._interrupt = False
|
1084
1088
|
|
1085
1089
|
# 2. Define call parameters
|
1086
1090
|
if prompt is not None and isinstance(prompt, str):
|
@@ -1211,6 +1215,9 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
|
1211
1215
|
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
1212
1216
|
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1213
1217
|
for i, t in enumerate(timesteps):
|
1218
|
+
if self.interrupt:
|
1219
|
+
continue
|
1220
|
+
|
1214
1221
|
# expand the latents if we are doing classifier free guidance
|
1215
1222
|
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
1216
1223
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
@@ -26,7 +26,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
|
|
26
26
|
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
27
27
|
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
28
28
|
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
29
|
-
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
|
29
|
+
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
|
30
30
|
from ...models.lora import adjust_lora_scale_text_encoder
|
31
31
|
from ...schedulers import KarrasDiffusionSchedulers
|
32
32
|
from ...utils import (
|
@@ -41,7 +41,6 @@ from ...utils.torch_utils import is_compiled_module, randn_tensor
|
|
41
41
|
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
42
42
|
from ..stable_diffusion import StableDiffusionPipelineOutput
|
43
43
|
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
44
|
-
from .multicontrolnet import MultiControlNetModel
|
45
44
|
|
46
45
|
|
47
46
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
@@ -976,6 +975,10 @@ class StableDiffusionControlNetInpaintPipeline(
|
|
976
975
|
def num_timesteps(self):
|
977
976
|
return self._num_timesteps
|
978
977
|
|
978
|
+
@property
|
979
|
+
def interrupt(self):
|
980
|
+
return self._interrupt
|
981
|
+
|
979
982
|
@torch.no_grad()
|
980
983
|
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
981
984
|
def __call__(
|
@@ -1191,6 +1194,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
|
1191
1194
|
self._guidance_scale = guidance_scale
|
1192
1195
|
self._clip_skip = clip_skip
|
1193
1196
|
self._cross_attention_kwargs = cross_attention_kwargs
|
1197
|
+
self._interrupt = False
|
1194
1198
|
|
1195
1199
|
# 2. Define call parameters
|
1196
1200
|
if prompt is not None and isinstance(prompt, str):
|
@@ -1375,6 +1379,9 @@ class StableDiffusionControlNetInpaintPipeline(
|
|
1375
1379
|
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
1376
1380
|
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1377
1381
|
for i, t in enumerate(timesteps):
|
1382
|
+
if self.interrupt:
|
1383
|
+
continue
|
1384
|
+
|
1378
1385
|
# expand the latents if we are doing classifier free guidance
|
1379
1386
|
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
1380
1387
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|