diffusers 0.15.1__py3-none-any.whl → 0.16.1__py3-none-any.whl
- diffusers/__init__.py +7 -2
- diffusers/configuration_utils.py +4 -0
- diffusers/loaders.py +262 -12
- diffusers/models/attention.py +31 -12
- diffusers/models/attention_processor.py +189 -0
- diffusers/models/controlnet.py +9 -2
- diffusers/models/embeddings.py +66 -0
- diffusers/models/modeling_pytorch_flax_utils.py +6 -0
- diffusers/models/modeling_utils.py +5 -2
- diffusers/models/transformer_2d.py +1 -1
- diffusers/models/unet_2d_condition.py +45 -6
- diffusers/models/vae.py +3 -0
- diffusers/pipelines/__init__.py +8 -0
- diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +25 -10
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +8 -0
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +8 -0
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
- diffusers/pipelines/deepfloyd_if/__init__.py +54 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +854 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +979 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1097 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1098 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1208 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +947 -0
- diffusers/pipelines/deepfloyd_if/safety_checker.py +59 -0
- diffusers/pipelines/deepfloyd_if/timesteps.py +579 -0
- diffusers/pipelines/deepfloyd_if/watermark.py +46 -0
- diffusers/pipelines/pipeline_utils.py +54 -25
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +37 -20
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +59 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +22 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +34 -30
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +93 -10
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +45 -6
- diffusers/schedulers/scheduling_ddpm.py +63 -16
- diffusers/schedulers/scheduling_heun_discrete.py +51 -1
- diffusers/utils/__init__.py +4 -1
- diffusers/utils/dummy_torch_and_transformers_objects.py +80 -5
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/hub_utils.py +4 -1
- diffusers/utils/import_utils.py +41 -0
- diffusers/utils/pil_utils.py +24 -0
- diffusers/utils/testing_utils.py +10 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/METADATA +1 -1
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/RECORD +57 -47
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/LICENSE +0 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/WHEEL +0 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/top_level.txt +0 -0
diffusers/pipelines/pipeline_utils.py

@@ -19,6 +19,7 @@ import importlib
 import inspect
 import os
 import re
+import sys
 import warnings
 from dataclasses import dataclass
 from pathlib import Path
@@ -29,7 +30,6 @@ import PIL
 import torch
 from huggingface_hub import hf_hub_download, model_info, snapshot_download
 from packaging import version
-from PIL import Image
 from tqdm.auto import tqdm

 import diffusers
@@ -55,6 +55,7 @@ from ..utils import (
     is_torch_version,
     is_transformers_available,
     logging,
+    numpy_to_pil,
 )

@@ -200,24 +201,24 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
     # .bin, .safetensors, ...
     weight_suffixs = [w.split(".")[-1] for w in weight_names]
     # -00001-of-00002
-    transformers_index_format = "\d{5}-of-\d{5}"
+    transformers_index_format = r"\d{5}-of-\d{5}"

     if variant is not None:
         # `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetenstors`
         variant_file_re = re.compile(
-
+            rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
         )
         # `text_encoder/pytorch_model.bin.index.fp16.json`
         variant_index_re = re.compile(
-
+            rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
         )

     # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetenstors`
     non_variant_file_re = re.compile(
-
+        rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
     )
     # `text_encoder/pytorch_model.bin.index.json`
-    non_variant_index_re = re.compile(
+    non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")

     if variant is not None:
         variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
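The intent of this hunk is only to switch the patterns to raw strings so `\d` is a proper regex token rather than an invalid Python string escape; matching behavior is unchanged. A minimal standalone sketch of how the variant pattern matches, with sample prefixes and suffixes assumed for illustration (the library builds these lists dynamically):

```python
import re

# Assumed sample values, not taken from the diff.
weight_prefixes = ["pytorch_model", "diffusion_pytorch_model"]
weight_suffixs = ["bin", "safetensors"]
transformers_index_format = r"\d{5}-of-\d{5}"
variant = "fp16"

variant_file_re = re.compile(
    rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
)

print(bool(variant_file_re.match("diffusion_pytorch_model.fp16.bin")))               # True
print(bool(variant_file_re.match("pytorch_model.fp16-00001-of-00002.safetensors")))  # True
print(bool(variant_file_re.match("diffusion_pytorch_model.bin")))                    # False
```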
@@ -540,11 +541,9 @@ class DiffusionPipeline(ConfigMixin):
            variant (`str`, *optional*):
                If specified, weights are saved in the format pytorch_model.<variant>.bin.
        """
-        self.save_config(save_directory)
-
        model_index_dict = dict(self.config)
-        model_index_dict.pop("_class_name")
-        model_index_dict.pop("_diffusers_version")
+        model_index_dict.pop("_class_name", None)
+        model_index_dict.pop("_diffusers_version", None)
        model_index_dict.pop("_module", None)

        expected_modules, optional_kwargs = self._get_signature_keys(self)
@@ -557,7 +556,6 @@ class DiffusionPipeline(ConfigMixin):
            return True

        model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
-
        for pipeline_component_name in model_index_dict.keys():
            sub_model = getattr(self, pipeline_component_name)
            model_cls = sub_model.__class__
@@ -571,7 +569,13 @@ class DiffusionPipeline(ConfigMixin):
            save_method_name = None
            # search for the model's base class in LOADABLE_CLASSES
            for library_name, library_classes in LOADABLE_CLASSES.items():
-
+                if library_name in sys.modules:
+                    library = importlib.import_module(library_name)
+                else:
+                    logger.info(
+                        f"{library_name} is not installed. Cannot save {pipeline_component_name} as {library_classes} from {library_name}"
+                    )
+
                for base_class, save_load_methods in library_classes.items():
                    class_candidate = getattr(library, base_class, None)
                    if class_candidate is not None and issubclass(model_cls, class_candidate):
@@ -581,6 +585,12 @@ class DiffusionPipeline(ConfigMixin):
                if save_method_name is not None:
                    break

+            if save_method_name is None:
+                logger.warn(f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved.")
+                # make sure that unsaveable components are not tried to be loaded afterward
+                self.register_to_config(**{pipeline_component_name: (None, None)})
+                continue
+
            save_method = getattr(sub_model, save_method_name)

            # Call the save method with the argument safe_serialization only if it's supported
@@ -596,6 +606,9 @@ class DiffusionPipeline(ConfigMixin):

            save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)

+        # finally save the config
+        self.save_config(save_directory)
+
    def to(
        self,
        torch_device: Optional[Union[str, torch.device]] = None,
@@ -610,7 +623,9 @@ class DiffusionPipeline(ConfigMixin):
            if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
                return False

-            return hasattr(module, "_hf_hook") and not isinstance(
+            return hasattr(module, "_hf_hook") and not isinstance(
+                module._hf_hook, (accelerate.hooks.CpuOffload, accelerate.hooks.AlignDevicesHook)
+            )

        def module_is_offloaded(module):
            if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
@@ -640,7 +655,20 @@ class DiffusionPipeline(ConfigMixin):

        is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
        for module in modules:
-            module
+            is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit
+
+            if is_loaded_in_8bit and torch_dtype is not None:
+                logger.warning(
+                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {torch_dtype} is not yet supported. Module is still in 8bit precision."
+                )
+
+            if is_loaded_in_8bit and torch_device is not None:
+                logger.warning(
+                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}."
+                )
+            else:
+                module.to(torch_device, torch_dtype)
+
            if (
                module.dtype == torch.float16
                and str(torch_device) in ["cpu"]
@@ -874,6 +902,9 @@ class DiffusionPipeline(ConfigMixin):

        config_dict = cls.load_config(cached_folder)

+        # pop out "_ignore_files" as it is only needed for download
+        config_dict.pop("_ignore_files", None)
+
        # 2. Define which model components should load variants
        # We retrieve the information by matching whether variant
        # model checkpoints exist in the subfolders
@@ -1045,7 +1076,7 @@ class DiffusionPipeline(ConfigMixin):
        return_cached_folder = kwargs.pop("return_cached_folder", False)
        if return_cached_folder:
            message = f"Passing `return_cached_folder=True` is deprecated and will be removed in `diffusers=0.17.0`. Please do the following instead: \n 1. Load the cached_folder via `cached_folder={cls}.download({pretrained_model_name_or_path})`. \n 2. Load the pipeline by loading from the cached folder: `pipeline={cls}.from_pretrained(cached_folder)`."
-            deprecate("return_cached_folder", "0.17.0", message
+            deprecate("return_cached_folder", "0.17.0", message)
            return model, cached_folder

        return model
@@ -1191,12 +1222,19 @@ class DiffusionPipeline(ConfigMixin):
        )

        config_dict = cls._dict_from_json_file(config_file)
+
+        ignore_filenames = config_dict.pop("_ignore_files", [])
+
        # retrieve all folder_names that contain relevant files
        folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]

        filenames = {sibling.rfilename for sibling in info.siblings}
        model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)

+        # remove ignored filenames
+        model_filenames = set(model_filenames) - set(ignore_filenames)
+        variant_filenames = set(variant_filenames) - set(ignore_filenames)
+
        # if the whole pipeline is cached we don't have to ping the Hub
        if revision in DEPRECATED_REVISION_ARGS and version.parse(
            version.parse(__version__).base_version
@@ -1357,16 +1395,7 @@ class DiffusionPipeline(ConfigMixin):
        """
        Convert a numpy image or a batch of images to a PIL image.
        """
-
-        images = images[None, ...]
-        images = (images * 255).round().astype("uint8")
-        if images.shape[-1] == 1:
-            # special case for grayscale (single channel) images
-            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
-        else:
-            pil_images = [Image.fromarray(image) for image in images]
-
-        return pil_images
+        return numpy_to_pil(images)

    def progress_bar(self, iterable=None, total=None):
        if not hasattr(self, "_progress_bar_config"):
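The pipeline-level helper now delegates to the shared utility imported earlier in this file. A quick sketch of the shared helper, assuming it is re-exported from `diffusers.utils` as the new import in this diff suggests (the input array is random illustration data):

```python
import numpy as np
from diffusers.utils import numpy_to_pil

# A batch of two 64x64 RGB images with float values in [0, 1].
images = np.random.rand(2, 64, 64, 3).astype("float32")
pil_images = numpy_to_pil(images)  # -> list of two PIL.Image.Image objects
```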
diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

@@ -31,33 +31,30 @@ from transformers import (
    CLIPVisionModelWithProjection,
)

-from
+from ...models import (
    AutoencoderKL,
    ControlNetModel,
+    PriorTransformer,
+    UNet2DConditionModel,
+)
+from ...schedulers import (
    DDIMScheduler,
    DDPMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
-    LDMTextToImagePipeline,
    LMSDiscreteScheduler,
    PNDMScheduler,
-    PriorTransformer,
-    StableDiffusionControlNetPipeline,
-    StableDiffusionPipeline,
-    StableUnCLIPImg2ImgPipeline,
-    StableUnCLIPPipeline,
    UnCLIPScheduler,
-    UNet2DConditionModel,
)
-from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
-from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
-from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
-from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
-
from ...utils import is_omegaconf_available, is_safetensors_available, logging
from ...utils.import_utils import BACKENDS_MAPPING
+from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
+from ..paint_by_example import PaintByExampleImageEncoder
+from ..pipeline_utils import DiffusionPipeline
+from .safety_checker import StableDiffusionSafetyChecker
+from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -990,7 +987,8 @@ def download_from_original_stable_diffusion_ckpt(
    clip_stats_path: Optional[str] = None,
    controlnet: Optional[bool] = None,
    load_safety_checker: bool = True,
-
+    pipeline_class: DiffusionPipeline = None,
+) -> DiffusionPipeline:
    """
    Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
    config file.
@@ -1018,6 +1016,8 @@ def download_from_original_stable_diffusion_ckpt(
        model_type (`str`, *optional*, defaults to `None`):
            The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder",
            "FrozenCLIPEmbedder", "PaintByExample"]`.
+        is_img2img (`bool`, *optional*, defaults to `False`):
+            Whether the model should be loaded as an img2img pipeline.
        extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
            checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
            `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
@@ -1026,12 +1026,29 @@ def download_from_original_stable_diffusion_ckpt(
            Whether the attention computation should always be upcasted. This is necessary when running stable
            diffusion 2.1.
        device (`str`, *optional*, defaults to `None`):
-            The device to use. Pass `None` to determine automatically.
-
-
+            The device to use. Pass `None` to determine automatically.
+        from_safetensors (`str`, *optional*, defaults to `False`):
+            If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
        load_safety_checker (`bool`, *optional*, defaults to `True`):
            Whether to load the safety checker or not. Defaults to `True`.
+        pipeline_class (`str`, *optional*, defaults to `None`):
+            The pipeline class to use. Pass `None` to determine automatically.
+        return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
    """
+
+    # import pipelines here to avoid circular import error when using from_ckpt method
+    from diffusers import (
+        LDMTextToImagePipeline,
+        PaintByExamplePipeline,
+        StableDiffusionControlNetPipeline,
+        StableDiffusionPipeline,
+        StableUnCLIPImg2ImgPipeline,
+        StableUnCLIPPipeline,
+    )
+
+    if pipeline_class is None:
+        pipeline_class = StableDiffusionPipeline
+
    if prediction_type == "v-prediction":
        prediction_type = "v_prediction"

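With the new `pipeline_class` argument, a caller can ask the converter for a pipeline other than the default `StableDiffusionPipeline`. A hedged usage sketch (the checkpoint path is a placeholder, and `original_config_file` is left to be auto-detected):

```python
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)

pipe = download_from_original_stable_diffusion_ckpt(
    checkpoint_path="./v1-5-pruned-emaonly.ckpt",   # placeholder local checkpoint
    pipeline_class=StableDiffusionImg2ImgPipeline,  # defaults to StableDiffusionPipeline when None
)
```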
@@ -1193,7 +1210,7 @@ def download_from_original_stable_diffusion_ckpt(
            requires_safety_checker=False,
        )
    else:
-        pipe =
+        pipe = pipeline_class(
            vae=vae,
            text_encoder=text_model,
            tokenizer=tokenizer,
@@ -1293,7 +1310,7 @@ def download_from_original_stable_diffusion_ckpt(
            feature_extractor=feature_extractor,
        )
    else:
-        pipe =
+        pipe = pipeline_class(
            vae=vae,
            text_encoder=text_model,
            tokenizer=tokenizer,
@@ -1320,7 +1337,7 @@ def download_controlnet_from_original_ckpt(
    upcast_attention: Optional[bool] = None,
    device: str = None,
    from_safetensors: bool = False,
-) ->
+) -> DiffusionPipeline:
    if not is_omegaconf_available():
        raise ValueError(BACKENDS_MAPPING["omegaconf"][1])

diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py

@@ -83,7 +83,7 @@ EXAMPLE_DOC_STRING = """
        ...     "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32
        ... )
        >>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
-        ...     "runwayml/stable-diffusion-v1-5", controlnet=controlnet,
+        ...     "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32
        ... )
        >>> params["controlnet"] = controlnet_params

diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py

@@ -56,7 +56,18 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
        scheduler: Any,
        max_noise_level: int = 350,
    ):
-        super().__init__(
+        super().__init__(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            low_res_scheduler=low_res_scheduler,
+            scheduler=scheduler,
+            safety_checker=None,
+            feature_extractor=None,
+            watermarker=None,
+            max_noise_level=max_noise_level,
+        )

    def __call__(
        self,
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

@@ -20,7 +20,7 @@ from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -53,13 +53,21 @@ EXAMPLE_DOC_STRING = """
"""


-class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
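The new `FromCkptMixin` base exposes `from_ckpt`, so a CompVis-style single-file checkpoint can be loaded directly into the text-to-image pipeline. A hedged sketch (the local path and prompt are placeholders):

```python
import torch
from diffusers import StableDiffusionPipeline

# Load a single-file .ckpt/.safetensors checkpoint straight into the pipeline.
pipe = StableDiffusionPipeline.from_ckpt(
    "./v1-5-pruned-emaonly.safetensors",  # placeholder local path
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
image = pipe("an astronaut riding a horse on the moon").images[0]
```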
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py

@@ -14,7 +14,7 @@

import inspect
import math
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -76,7 +76,7 @@ class AttentionStore:

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= 0 and is_cross:
-            if attn.shape[1] == self.attn_res
+            if attn.shape[1] == np.prod(self.attn_res):
                self.step_store[place_in_unet].append(attn)

        self.cur_att_layer += 1
@@ -98,7 +98,7 @@ class AttentionStore:
        attention_maps = self.get_average_attention()
        for location in from_where:
            for item in attention_maps[location]:
-                cross_maps = item.reshape(-1, self.attn_res, self.attn_res, item.shape[-1])
+                cross_maps = item.reshape(-1, self.attn_res[0], self.attn_res[1], item.shape[-1])
                out.append(cross_maps)
        out = torch.cat(out, dim=0)
        out = out.sum(0) / out.shape[0]
@@ -109,7 +109,7 @@ class AttentionStore:
        self.step_store = self.get_empty_store()
        self.attention_store = {}

-    def __init__(self, attn_res
+    def __init__(self, attn_res):
        """
        Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion
        process
@@ -724,7 +724,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
        max_iter_to_alter: int = 25,
        thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8},
        scale_factor: int = 20,
-        attn_res: int = 16,
+        attn_res: Optional[Tuple[int]] = (16, 16),
    ):
        r"""
        Function invoked when calling the pipeline for generation.
@@ -796,8 +796,8 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
                Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in.
            scale_factor (`int`, *optional*, default to 20):
                Scale factor that controls the step size of each Attend and Excite update.
-            attn_res (`
-                The resolution of
+            attn_res (`tuple`, *optional*, default computed from width and height):
+                The 2D resolution of the semantic attention map.

        Examples:

@@ -870,7 +870,9 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

-
+        if attn_res is None:
+            attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))
+        self.attention_store = AttentionStore(attn_res)
        self.register_attention_control()

        # default config for step size from original repo
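The attention-map resolution is now a `(width, height)` tuple, or `None` to derive it from the requested image size. A hedged usage sketch of the Attend-and-Excite call (model ID, prompt, and token indices are illustrative):

```python
import torch
from diffusers import StableDiffusionAttendAndExcitePipeline

pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    prompt="a cat and a frog",
    token_indices=[2, 5],  # tokens to attend to and excite
    attn_res=(16, 16),     # was the scalar `attn_res=16` before 0.16
    num_inference_steps=50,
).images[0]
```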
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

@@ -118,6 +118,7 @@ class MultiControlNetModel(ModelMixin):
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple]:
        for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
@@ -131,6 +132,7 @@ class MultiControlNetModel(ModelMixin):
                timestep_cond,
                attention_mask,
                cross_attention_kwargs,
+                guess_mode,
                return_dict,
            )

@@ -154,6 +156,9 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+
    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
@@ -244,6 +249,24 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
        """
        self.vae.disable_slicing()

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+        several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+        """
+        self.vae.enable_tiling()
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
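These are the same tiling toggles the text-to-image pipeline already exposes, copied onto the ControlNet pipeline so large outputs fit in memory. A hedged sketch (model IDs are the ones used elsewhere in this diff, the workflow itself is illustrative):

```python
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

pipe.enable_vae_tiling()   # decode the latents tile by tile to reduce peak memory
# ... run the pipeline on a large conditioning image ...
pipe.disable_vae_tiling()  # back to single-pass decoding
```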
@@ -627,7 +650,16 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
        )

    def prepare_image(
-        self,
+        self,
+        image,
+        width,
+        height,
+        batch_size,
+        num_images_per_prompt,
+        device,
+        dtype,
+        do_classifier_free_guidance=False,
+        guess_mode=False,
    ):
        if not isinstance(image, torch.Tensor):
            if isinstance(image, PIL.Image.Image):
|
|
664
696
|
|
665
697
|
image = image.to(device=device, dtype=dtype)
|
666
698
|
|
667
|
-
if do_classifier_free_guidance:
|
699
|
+
if do_classifier_free_guidance and not guess_mode:
|
668
700
|
image = torch.cat([image] * 2)
|
669
701
|
|
670
702
|
return image
|
@@ -747,6 +779,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
|
747
779
|
callback_steps: int = 1,
|
748
780
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
749
781
|
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
782
|
+
guess_mode: bool = False,
|
750
783
|
):
|
751
784
|
r"""
|
752
785
|
Function invoked when calling the pipeline for generation.
|
@@ -819,6 +852,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
                to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
                corresponding scale as a list.
+            guess_mode (`bool`, *optional*, defaults to `False`):
+                In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
+                you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
+
        Examples:

        Returns:
@@ -883,6 +920,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                device=device,
                dtype=self.controlnet.dtype,
                do_classifier_free_guidance=do_classifier_free_guidance,
+                guess_mode=guess_mode,
            )
        elif isinstance(self.controlnet, MultiControlNetModel):
            images = []
@@ -897,6 +935,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                    device=device,
                    dtype=self.controlnet.dtype,
                    do_classifier_free_guidance=do_classifier_free_guidance,
+                    guess_mode=guess_mode,
                )

                images.append(image_)
@@ -934,15 +973,31 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # controlnet(s) inference
+                if guess_mode and do_classifier_free_guidance:
+                    # Infer ControlNet only for the conditional batch.
+                    controlnet_latent_model_input = latents
+                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+                else:
+                    controlnet_latent_model_input = latent_model_input
+                    controlnet_prompt_embeds = prompt_embeds
+
                down_block_res_samples, mid_block_res_sample = self.controlnet(
-
+                    controlnet_latent_model_input,
                    t,
-                    encoder_hidden_states=
+                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=image,
                    conditioning_scale=controlnet_conditioning_scale,
+                    guess_mode=guess_mode,
                    return_dict=False,
                )

+                if guess_mode and do_classifier_free_guidance:
+                    # Infered ControlNet only for the conditional batch.
+                    # To apply the output of ControlNet to both the unconditional and conditional batches,
+                    # add 0 to the unconditional batch to keep it unchanged.
+                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
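Putting the option together, `guess_mode=True` runs the ControlNet only on the conditional batch while the UNet still receives both guidance batches, so the prompt can be empty or very weak. A hedged end-to-end sketch (the conditioning-image URL is a placeholder):

```python
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

canny_image = load_image("https://example.com/canny_edges.png")  # placeholder URL

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# With guess_mode the prompt can be empty; a guidance_scale around 3.0-5.0 is recommended.
image = pipe("", canny_image, guess_mode=True, guidance_scale=3.0).images[0]
```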
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py

@@ -23,7 +23,7 @@ from packaging import version
from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation

from ...configuration_utils import FrozenDict
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
@@ -55,13 +55,20 @@ def preprocess(image):
    return image


-class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
    r"""
    Pipeline for text-guided image to image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
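With `LoraLoaderMixin` mixed in, the depth2img pipeline can now pull LoRA weights onto its UNet and text encoder. A hedged sketch (the repository ID and LoRA path are placeholders, not taken from the diff):

```python
import torch
from diffusers import StableDiffusionDepth2ImgPipeline

pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-depth", torch_dtype=torch.float16
).to("cuda")

# Load LoRA weights previously saved via LoraLoaderMixin.save_lora_weights (placeholder path).
pipe.load_lora_weights("./my_lora_weights")
```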
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

@@ -23,7 +23,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -92,13 +92,21 @@ def preprocess(image):
    return image


-class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
    r"""
    Pipeline for text-guided image to image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.