diffusers 0.15.1__py3-none-any.whl → 0.16.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +7 -2
- diffusers/configuration_utils.py +4 -0
- diffusers/loaders.py +262 -12
- diffusers/models/attention.py +31 -12
- diffusers/models/attention_processor.py +189 -0
- diffusers/models/controlnet.py +9 -2
- diffusers/models/embeddings.py +66 -0
- diffusers/models/modeling_pytorch_flax_utils.py +6 -0
- diffusers/models/modeling_utils.py +5 -2
- diffusers/models/transformer_2d.py +1 -1
- diffusers/models/unet_2d_condition.py +45 -6
- diffusers/models/vae.py +3 -0
- diffusers/pipelines/__init__.py +8 -0
- diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +25 -10
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +8 -0
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +8 -0
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
- diffusers/pipelines/deepfloyd_if/__init__.py +54 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +854 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +979 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1097 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1098 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1208 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +947 -0
- diffusers/pipelines/deepfloyd_if/safety_checker.py +59 -0
- diffusers/pipelines/deepfloyd_if/timesteps.py +579 -0
- diffusers/pipelines/deepfloyd_if/watermark.py +46 -0
- diffusers/pipelines/pipeline_utils.py +54 -25
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +37 -20
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +59 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +22 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +34 -30
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +93 -10
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +45 -6
- diffusers/schedulers/scheduling_ddpm.py +63 -16
- diffusers/schedulers/scheduling_heun_discrete.py +51 -1
- diffusers/utils/__init__.py +4 -1
- diffusers/utils/dummy_torch_and_transformers_objects.py +80 -5
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/hub_utils.py +4 -1
- diffusers/utils/import_utils.py +41 -0
- diffusers/utils/pil_utils.py +24 -0
- diffusers/utils/testing_utils.py +10 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/METADATA +1 -1
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/RECORD +57 -47
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/LICENSE +0 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/WHEEL +0 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/top_level.txt +0 -0
--- a/diffusers/pipelines/pipeline_utils.py
+++ b/diffusers/pipelines/pipeline_utils.py
@@ -19,6 +19,7 @@ import importlib
 import inspect
 import os
 import re
+import sys
 import warnings
 from dataclasses import dataclass
 from pathlib import Path
@@ -29,7 +30,6 @@ import PIL
 import torch
 from huggingface_hub import hf_hub_download, model_info, snapshot_download
 from packaging import version
-from PIL import Image
 from tqdm.auto import tqdm
 
 import diffusers
@@ -55,6 +55,7 @@ from ..utils import (
     is_torch_version,
     is_transformers_available,
     logging,
+    numpy_to_pil,
 )
 
 
@@ -200,24 +201,24 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
     # .bin, .safetensors, ...
     weight_suffixs = [w.split(".")[-1] for w in weight_names]
     # -00001-of-00002
-    transformers_index_format = "\d{5}-of-\d{5}"
+    transformers_index_format = r"\d{5}-of-\d{5}"
 
     if variant is not None:
         # `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetenstors`
         variant_file_re = re.compile(
-            f"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
+            rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
         )
         # `text_encoder/pytorch_model.bin.index.fp16.json`
         variant_index_re = re.compile(
-            f"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
+            rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
         )
 
     # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetenstors`
     non_variant_file_re = re.compile(
-        f"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
+        rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
     )
     # `text_encoder/pytorch_model.bin.index.json`
-    non_variant_index_re = re.compile(f"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
+    non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
 
     if variant is not None:
         variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
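The escaping hunk only adds `r` prefixes so the `\d`/`\.` sequences reach the `re` module literally instead of relying on Python's deprecated unknown string escapes. A minimal standalone illustration of the sharded-index pattern (the filename here is made up):

```python
import re

# Without the r-prefix, "\d" is an invalid string escape (a DeprecationWarning
# today, slated to become an error in future CPython releases).
transformers_index_format = r"\d{5}-of-\d{5}"

pattern = re.compile(rf"model-{transformers_index_format}\.safetensors$")
print(bool(pattern.match("model-00001-of-00002.safetensors")))  # True
```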
--- a/diffusers/pipelines/pipeline_utils.py
+++ b/diffusers/pipelines/pipeline_utils.py
@@ -540,11 +541,9 @@ class DiffusionPipeline(ConfigMixin):
             variant (`str`, *optional*):
                 If specified, weights are saved in the format pytorch_model.<variant>.bin.
         """
-        self.save_config(save_directory)
-
         model_index_dict = dict(self.config)
-        model_index_dict.pop("_class_name")
-        model_index_dict.pop("_diffusers_version")
+        model_index_dict.pop("_class_name", None)
+        model_index_dict.pop("_diffusers_version", None)
         model_index_dict.pop("_module", None)
 
         expected_modules, optional_kwargs = self._get_signature_keys(self)
@@ -557,7 +556,6 @@ class DiffusionPipeline(ConfigMixin):
             return True
 
         model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
-
         for pipeline_component_name in model_index_dict.keys():
             sub_model = getattr(self, pipeline_component_name)
             model_cls = sub_model.__class__
@@ -571,7 +569,13 @@ class DiffusionPipeline(ConfigMixin):
             save_method_name = None
             # search for the model's base class in LOADABLE_CLASSES
             for library_name, library_classes in LOADABLE_CLASSES.items():
-                library = importlib.import_module(library_name)
+                if library_name in sys.modules:
+                    library = importlib.import_module(library_name)
+                else:
+                    logger.info(
+                        f"{library_name} is not installed. Cannot save {pipeline_component_name} as {library_classes} from {library_name}"
+                    )
+
                 for base_class, save_load_methods in library_classes.items():
                     class_candidate = getattr(library, base_class, None)
                     if class_candidate is not None and issubclass(model_cls, class_candidate):
@@ -581,6 +585,12 @@ class DiffusionPipeline(ConfigMixin):
                 if save_method_name is not None:
                     break
 
+            if save_method_name is None:
+                logger.warn(f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved.")
+                # make sure that unsaveable components are not tried to be loaded afterward
+                self.register_to_config(**{pipeline_component_name: (None, None)})
+                continue
+
             save_method = getattr(sub_model, save_method_name)
 
             # Call the save method with the argument safe_serialization only if it's supported
@@ -596,6 +606,9 @@ class DiffusionPipeline(ConfigMixin):
 
             save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
 
+        # finally save the config
+        self.save_config(save_directory)
+
     def to(
         self,
         torch_device: Optional[Union[str, torch.device]] = None,
@@ -610,7 +623,9 @@ class DiffusionPipeline(ConfigMixin):
             if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
                 return False
 
-            return hasattr(module, "_hf_hook") and not isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
+            return hasattr(module, "_hf_hook") and not isinstance(
+                module._hf_hook, (accelerate.hooks.CpuOffload, accelerate.hooks.AlignDevicesHook)
+            )
 
         def module_is_offloaded(module):
             if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
@@ -640,7 +655,20 @@ class DiffusionPipeline(ConfigMixin):
 
         is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
        for module in modules:
-            module.to(torch_device, torch_dtype)
+            is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit
+
+            if is_loaded_in_8bit and torch_dtype is not None:
+                logger.warning(
+                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {torch_dtype} is not yet supported. Module is still in 8bit precision."
+                )
+
+            if is_loaded_in_8bit and torch_device is not None:
+                logger.warning(
+                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}."
+                )
+            else:
+                module.to(torch_device, torch_dtype)
+
             if (
                 module.dtype == torch.float16
                 and str(torch_device) in ["cpu"]
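The `to()` hunk makes pipelines containing 8-bit (e.g. bitsandbytes-quantized) modules degrade gracefully instead of attempting an unsupported cast; a sketch of the resulting behaviour, assuming such a module is present:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# Modules flagged with `is_loaded_in_8bit` are left untouched (a warning is
# logged); every other module is moved and cast exactly as before.
pipe = pipe.to("cuda", torch.float16)
```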
--- a/diffusers/pipelines/pipeline_utils.py
+++ b/diffusers/pipelines/pipeline_utils.py
@@ -874,6 +902,9 @@ class DiffusionPipeline(ConfigMixin):
 
         config_dict = cls.load_config(cached_folder)
 
+        # pop out "_ignore_files" as it is only needed for download
+        config_dict.pop("_ignore_files", None)
+
         # 2. Define which model components should load variants
         # We retrieve the information by matching whether variant
         # model checkpoints exist in the subfolders
@@ -1045,7 +1076,7 @@ class DiffusionPipeline(ConfigMixin):
         return_cached_folder = kwargs.pop("return_cached_folder", False)
         if return_cached_folder:
             message = f"Passing `return_cached_folder=True` is deprecated and will be removed in `diffusers=0.17.0`. Please do the following instead: \n 1. Load the cached_folder via `cached_folder={cls}.download({pretrained_model_name_or_path})`. \n 2. Load the pipeline by loading from the cached folder: `pipeline={cls}.from_pretrained(cached_folder)`."
-            deprecate("return_cached_folder", "0.17.0", message, take_from=kwargs)
+            deprecate("return_cached_folder", "0.17.0", message)
             return model, cached_folder
 
         return model
@@ -1191,12 +1222,19 @@ class DiffusionPipeline(ConfigMixin):
             )
 
         config_dict = cls._dict_from_json_file(config_file)
+
+        ignore_filenames = config_dict.pop("_ignore_files", [])
+
         # retrieve all folder_names that contain relevant files
         folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]
 
         filenames = {sibling.rfilename for sibling in info.siblings}
         model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
 
+        # remove ignored filenames
+        model_filenames = set(model_filenames) - set(ignore_filenames)
+        variant_filenames = set(variant_filenames) - set(ignore_filenames)
+
         # if the whole pipeline is cached we don't have to ping the Hub
         if revision in DEPRECATED_REVISION_ARGS and version.parse(
             version.parse(__version__).base_version
@@ -1357,16 +1395,7 @@ class DiffusionPipeline(ConfigMixin):
         """
         Convert a numpy image or a batch of images to a PIL image.
         """
-        if images.ndim == 3:
-            images = images[None, ...]
-        images = (images * 255).round().astype("uint8")
-        if images.shape[-1] == 1:
-            # special case for grayscale (single channel) images
-            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
-        else:
-            pil_images = [Image.fromarray(image) for image in images]
-
-        return pil_images
+        return numpy_to_pil(images)
 
     def progress_bar(self, iterable=None, total=None):
         if not hasattr(self, "_progress_bar_config"):
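`DiffusionPipeline.numpy_to_pil` is now a thin wrapper around the helper that moved into `diffusers.utils` (see `pil_utils.py` in the file list above); a sketch of calling it directly:

```python
import numpy as np
from diffusers.utils import numpy_to_pil

images = np.random.rand(2, 64, 64, 3)  # a batch of HWC float images in [0, 1]
pil_images = numpy_to_pil(images)      # -> list of two 64x64 PIL.Image objects
```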
--- a/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+++ b/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -31,33 +31,30 @@ from transformers import (
     CLIPVisionModelWithProjection,
 )
 
-from diffusers import (
+from ...models import (
     AutoencoderKL,
     ControlNetModel,
+    PriorTransformer,
+    UNet2DConditionModel,
+)
+from ...schedulers import (
     DDIMScheduler,
     DDPMScheduler,
     DPMSolverMultistepScheduler,
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
     HeunDiscreteScheduler,
-    LDMTextToImagePipeline,
     LMSDiscreteScheduler,
     PNDMScheduler,
-    PriorTransformer,
-    StableDiffusionControlNetPipeline,
-    StableDiffusionPipeline,
-    StableUnCLIPImg2ImgPipeline,
-    StableUnCLIPPipeline,
     UnCLIPScheduler,
-    UNet2DConditionModel,
 )
-from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
-from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
-from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
-from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
-
 from ...utils import is_omegaconf_available, is_safetensors_available, logging
 from ...utils.import_utils import BACKENDS_MAPPING
+from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
+from ..paint_by_example import PaintByExampleImageEncoder
+from ..pipeline_utils import DiffusionPipeline
+from .safety_checker import StableDiffusionSafetyChecker
+from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -990,7 +987,8 @@ def download_from_original_stable_diffusion_ckpt(
     clip_stats_path: Optional[str] = None,
     controlnet: Optional[bool] = None,
     load_safety_checker: bool = True,
-) -> StableDiffusionPipeline:
+    pipeline_class: DiffusionPipeline = None,
+) -> DiffusionPipeline:
     """
     Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
     config file.
@@ -1018,6 +1016,8 @@ def download_from_original_stable_diffusion_ckpt(
         model_type (`str`, *optional*, defaults to `None`):
            The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder",
            "FrozenCLIPEmbedder", "PaintByExample"]`.
+        is_img2img (`bool`, *optional*, defaults to `False`):
+            Whether the model should be loaded as an img2img pipeline.
         extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
             checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
             `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
@@ -1026,12 +1026,29 @@ def download_from_original_stable_diffusion_ckpt(
             Whether the attention computation should always be upcasted. This is necessary when running stable
             diffusion 2.1.
         device (`str`, *optional*, defaults to `None`):
-            The device to use. Pass `None` to determine automatically.
-
-
+            The device to use. Pass `None` to determine automatically.
+        from_safetensors (`str`, *optional*, defaults to `False`):
+            If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
         load_safety_checker (`bool`, *optional*, defaults to `True`):
             Whether to load the safety checker or not. Defaults to `True`.
+        pipeline_class (`str`, *optional*, defaults to `None`):
+            The pipeline class to use. Pass `None` to determine automatically.
+        return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
     """
+
+    # import pipelines here to avoid circular import error when using from_ckpt method
+    from diffusers import (
+        LDMTextToImagePipeline,
+        PaintByExamplePipeline,
+        StableDiffusionControlNetPipeline,
+        StableDiffusionPipeline,
+        StableUnCLIPImg2ImgPipeline,
+        StableUnCLIPPipeline,
+    )
+
+    if pipeline_class is None:
+        pipeline_class = StableDiffusionPipeline
+
     if prediction_type == "v-prediction":
         prediction_type = "v_prediction"
 
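The imports are deferred into the function body to break the circular `diffusers` ↔ `convert_from_ckpt` dependency introduced by `from_ckpt`, and the new `pipeline_class` argument lets callers override the `StableDiffusionPipeline` default. A sketch of the override (local paths are illustrative):

```python
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)

pipe = download_from_original_stable_diffusion_ckpt(
    "./v1-5-pruned-emaonly.ckpt",                # illustrative local checkpoint
    original_config_file="./v1-inference.yaml",  # illustrative config path
    pipeline_class=StableDiffusionImg2ImgPipeline,
)
```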
--- a/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+++ b/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -1193,7 +1210,7 @@ def download_from_original_stable_diffusion_ckpt(
             requires_safety_checker=False,
         )
     else:
-        pipe = StableDiffusionPipeline(
+        pipe = pipeline_class(
             vae=vae,
             text_encoder=text_model,
             tokenizer=tokenizer,
@@ -1293,7 +1310,7 @@ def download_from_original_stable_diffusion_ckpt(
             feature_extractor=feature_extractor,
         )
     else:
-        pipe = StableDiffusionPipeline(
+        pipe = pipeline_class(
             vae=vae,
             text_encoder=text_model,
             tokenizer=tokenizer,
@@ -1320,7 +1337,7 @@ def download_controlnet_from_original_ckpt(
     upcast_attention: Optional[bool] = None,
     device: str = None,
     from_safetensors: bool = False,
-) -> StableDiffusionPipeline:
+) -> DiffusionPipeline:
     if not is_omegaconf_available():
         raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
 
--- a/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py
+++ b/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py
@@ -83,7 +83,7 @@ EXAMPLE_DOC_STRING = """
         ...     "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32
         ... )
         >>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
-        ...     "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.float32
+        ...     "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32
         ... )
         >>> params["controlnet"] = controlnet_params
 
--- a/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
+++ b/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
@@ -56,7 +56,18 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         scheduler: Any,
         max_noise_level: int = 350,
     ):
-        super().__init__(vae, text_encoder, tokenizer, unet, low_res_scheduler, scheduler, max_noise_level)
+        super().__init__(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            low_res_scheduler=low_res_scheduler,
+            scheduler=scheduler,
+            safety_checker=None,
+            feature_extractor=None,
+            watermarker=None,
+            max_noise_level=max_noise_level,
+        )
 
     def __call__(
         self,
--- a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -20,7 +20,7 @@ from packaging import version
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -53,13 +53,21 @@ EXAMPLE_DOC_STRING = """
 """
 
 
-class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.
 
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
 
+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
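With `FromCkptMixin` in the base classes, a single-file CompVis-style checkpoint can be loaded in one call; a sketch (the checkpoint URL is illustrative):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_ckpt(
    "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt"
)
```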
--- a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
@@ -14,7 +14,7 @@
 
 import inspect
 import math
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -76,7 +76,7 @@ class AttentionStore:
 
     def __call__(self, attn, is_cross: bool, place_in_unet: str):
         if self.cur_att_layer >= 0 and is_cross:
-            if attn.shape[1] == self.attn_res**2:
+            if attn.shape[1] == np.prod(self.attn_res):
                 self.step_store[place_in_unet].append(attn)
 
         self.cur_att_layer += 1
@@ -98,7 +98,7 @@ class AttentionStore:
         attention_maps = self.get_average_attention()
         for location in from_where:
             for item in attention_maps[location]:
-                cross_maps = item.reshape(-1, self.attn_res, self.attn_res, item.shape[-1])
+                cross_maps = item.reshape(-1, self.attn_res[0], self.attn_res[1], item.shape[-1])
                 out.append(cross_maps)
         out = torch.cat(out, dim=0)
         out = out.sum(0) / out.shape[0]
@@ -109,7 +109,7 @@ class AttentionStore:
         self.step_store = self.get_empty_store()
         self.attention_store = {}
 
-    def __init__(self, attn_res=16):
+    def __init__(self, attn_res):
         """
         Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion
         process
@@ -724,7 +724,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         max_iter_to_alter: int = 25,
         thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8},
         scale_factor: int = 20,
-        attn_res: int = 16,
+        attn_res: Optional[Tuple[int]] = (16, 16),
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -796,8 +796,8 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                 Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in.
             scale_factor (`int`, *optional*, default to 20):
                 Scale factor that controls the step size of each Attend and Excite update.
-            attn_res (`int`, *optional*, default to 16):
-                The resolution of most semantic attention map.
+            attn_res (`tuple`, *optional*, default computed from width and height):
+                The 2D resolution of the semantic attention map.
 
         Examples:
 
@@ -870,7 +870,9 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        self.attention_store = AttentionStore(attn_res)
+        if attn_res is None:
+            attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))
+        self.attention_store = AttentionStore(attn_res)
         self.register_attention_control()
 
         # default config for step size from original repo
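`attn_res` changes from a square `int` to a `(width, height)` tuple so that non-square generations index the attention map correctly; the fallback added in the last hunk derives one attention cell per 32x32-pixel patch:

```python
import numpy as np

# Default attn_res for a non-square 768x512 request, mirroring the hunk above.
width, height = 768, 512
attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))  # (24, 16)
assert np.prod(attn_res) == 24 * 16  # matches the AttentionStore.__call__ check
```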
--- a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
@@ -118,6 +118,7 @@ class MultiControlNetModel(ModelMixin):
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple]:
         for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
@@ -131,6 +132,7 @@ class MultiControlNetModel(ModelMixin):
                 timestep_cond,
                 attention_mask,
                 cross_attention_kwargs,
+                guess_mode,
                 return_dict,
             )
 
@@ -154,6 +156,9 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
 
+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
@@ -244,6 +249,24 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         """
         self.vae.disable_slicing()
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+        several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+        """
+        self.vae.enable_tiling()
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
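The tiling switches are copied verbatim from `StableDiffusionPipeline`, so the ControlNet pipeline now exposes the same memory/size trade-off; a usage sketch (model ids as in the Flax example above):

```python
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet
)
pipe.enable_vae_tiling()   # decode in tiles: slower, but much less memory
# ... generate large images here ...
pipe.disable_vae_tiling()  # restore single-pass decoding
```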
--- a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
@@ -627,7 +650,16 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         )
 
     def prepare_image(
-        self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance
+        self,
+        image,
+        width,
+        height,
+        batch_size,
+        num_images_per_prompt,
+        device,
+        dtype,
+        do_classifier_free_guidance=False,
+        guess_mode=False,
     ):
         if not isinstance(image, torch.Tensor):
             if isinstance(image, PIL.Image.Image):
@@ -664,7 +696,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
 
         image = image.to(device=device, dtype=dtype)
 
-        if do_classifier_free_guidance:
+        if do_classifier_free_guidance and not guess_mode:
             image = torch.cat([image] * 2)
 
         return image
@@ -747,6 +779,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+        guess_mode: bool = False,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -819,6 +852,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                 The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
                 to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
                 corresponding scale as a list.
+            guess_mode (`bool`, *optional*, defaults to `False`):
+                In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
+                you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
+
         Examples:
 
         Returns:
@@ -883,6 +920,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                 device=device,
                 dtype=self.controlnet.dtype,
                 do_classifier_free_guidance=do_classifier_free_guidance,
+                guess_mode=guess_mode,
             )
         elif isinstance(self.controlnet, MultiControlNetModel):
             images = []
@@ -897,6 +935,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                     device=device,
                     dtype=self.controlnet.dtype,
                     do_classifier_free_guidance=do_classifier_free_guidance,
+                    guess_mode=guess_mode,
                 )
 
                 images.append(image_)
@@ -934,15 +973,31 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 # controlnet(s) inference
+                if guess_mode and do_classifier_free_guidance:
+                    # Infer ControlNet only for the conditional batch.
+                    controlnet_latent_model_input = latents
+                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+                else:
+                    controlnet_latent_model_input = latent_model_input
+                    controlnet_prompt_embeds = prompt_embeds
+
                 down_block_res_samples, mid_block_res_sample = self.controlnet(
-                    latent_model_input,
+                    controlnet_latent_model_input,
                     t,
-                    encoder_hidden_states=prompt_embeds,
+                    encoder_hidden_states=controlnet_prompt_embeds,
                     controlnet_cond=image,
                     conditioning_scale=controlnet_conditioning_scale,
+                    guess_mode=guess_mode,
                     return_dict=False,
                 )
 
+                if guess_mode and do_classifier_free_guidance:
+                    # Infered ControlNet only for the conditional batch.
+                    # To apply the output of ControlNet to both the unconditional and conditional batches,
+                    # add 0 to the unconditional batch to keep it unchanged.
+                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
                 # predict the noise residual
                 noise_pred = self.unet(
                     latent_model_input,
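Putting the guess-mode pieces together: the conditioning image is not duplicated for CFG, the ControlNet only sees the conditional half of the batch, and zeroed residuals are concatenated in for the unconditional half. A sketch continuing from the pipeline built above (`canny_image` is an assumed, pre-computed edge map):

```python
image = pipe(
    "",                  # guess mode is designed to work even without a prompt
    image=canny_image,   # assumption: a PIL.Image Canny edge map
    guess_mode=True,
    guidance_scale=3.5,  # the docstring above recommends 3.0-5.0
).images[0]
```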
--- a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
@@ -23,7 +23,7 @@ from packaging import version
 from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation
 
 from ...configuration_utils import FrozenDict
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
@@ -55,13 +55,20 @@ def preprocess(image):
     return image
 
 
-class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
     r"""
     Pipeline for text-guided image to image generation using Stable Diffusion.
 
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
 
+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.

--- a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -23,7 +23,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -92,13 +92,21 @@ def preprocess(image):
     return image
 
 
-class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
     r"""
     Pipeline for text-guided image to image generation using Stable Diffusion.
 
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
 
+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.