diffusers 0.15.1__py3-none-any.whl → 0.16.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. diffusers/__init__.py +7 -2
  2. diffusers/configuration_utils.py +4 -0
  3. diffusers/loaders.py +262 -12
  4. diffusers/models/attention.py +31 -12
  5. diffusers/models/attention_processor.py +189 -0
  6. diffusers/models/controlnet.py +9 -2
  7. diffusers/models/embeddings.py +66 -0
  8. diffusers/models/modeling_pytorch_flax_utils.py +6 -0
  9. diffusers/models/modeling_utils.py +5 -2
  10. diffusers/models/transformer_2d.py +1 -1
  11. diffusers/models/unet_2d_condition.py +45 -6
  12. diffusers/models/vae.py +3 -0
  13. diffusers/pipelines/__init__.py +8 -0
  14. diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +25 -10
  15. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +8 -0
  16. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +8 -0
  17. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
  18. diffusers/pipelines/deepfloyd_if/__init__.py +54 -0
  19. diffusers/pipelines/deepfloyd_if/pipeline_if.py +854 -0
  20. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +979 -0
  21. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1097 -0
  22. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1098 -0
  23. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1208 -0
  24. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +947 -0
  25. diffusers/pipelines/deepfloyd_if/safety_checker.py +59 -0
  26. diffusers/pipelines/deepfloyd_if/timesteps.py +579 -0
  27. diffusers/pipelines/deepfloyd_if/watermark.py +46 -0
  28. diffusers/pipelines/pipeline_utils.py +54 -25
  29. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +37 -20
  30. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +1 -1
  31. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +12 -1
  32. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -2
  33. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -8
  34. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +59 -4
  35. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +9 -2
  36. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -2
  37. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +9 -2
  38. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +22 -12
  39. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +9 -2
  40. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +34 -30
  41. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +93 -10
  42. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +45 -6
  43. diffusers/schedulers/scheduling_ddpm.py +63 -16
  44. diffusers/schedulers/scheduling_heun_discrete.py +51 -1
  45. diffusers/utils/__init__.py +4 -1
  46. diffusers/utils/dummy_torch_and_transformers_objects.py +80 -5
  47. diffusers/utils/dynamic_modules_utils.py +1 -1
  48. diffusers/utils/hub_utils.py +4 -1
  49. diffusers/utils/import_utils.py +41 -0
  50. diffusers/utils/pil_utils.py +24 -0
  51. diffusers/utils/testing_utils.py +10 -0
  52. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/METADATA +1 -1
  53. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/RECORD +57 -47
  54. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/LICENSE +0 -0
  55. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/WHEEL +0 -0
  56. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/entry_points.txt +0 -0
  57. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/top_level.txt +0 -0
diffusers/pipelines/pipeline_utils.py

@@ -19,6 +19,7 @@ import importlib
 import inspect
 import os
 import re
+import sys
 import warnings
 from dataclasses import dataclass
 from pathlib import Path
@@ -29,7 +30,6 @@ import PIL
 import torch
 from huggingface_hub import hf_hub_download, model_info, snapshot_download
 from packaging import version
-from PIL import Image
 from tqdm.auto import tqdm

 import diffusers
@@ -55,6 +55,7 @@ from ..utils import (
     is_torch_version,
     is_transformers_available,
     logging,
+    numpy_to_pil,
 )

@@ -200,24 +201,24 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
     # .bin, .safetensors, ...
     weight_suffixs = [w.split(".")[-1] for w in weight_names]
     # -00001-of-00002
-    transformers_index_format = "\d{5}-of-\d{5}"
+    transformers_index_format = r"\d{5}-of-\d{5}"

     if variant is not None:
         # `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetenstors`
         variant_file_re = re.compile(
-            f"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
+            rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
         )
         # `text_encoder/pytorch_model.bin.index.fp16.json`
         variant_index_re = re.compile(
-            f"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
+            rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
         )

     # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetenstors`
     non_variant_file_re = re.compile(
-        f"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
+        rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
     )
     # `text_encoder/pytorch_model.bin.index.json`
-    non_variant_index_re = re.compile(f"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
+    non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")

     if variant is not None:
         variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
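Note: prefixing these f-strings with `r` silences Python's invalid-escape-sequence warnings for `\.` and `\d` without changing the matched patterns. A minimal sketch of what the compiled variant pattern accepts, using illustrative prefix/suffix lists (the real ones are built earlier in the same function):

    import re

    # Illustrative values; the real lists come from the surrounding function.
    weight_prefixes = ["diffusion_pytorch_model", "model"]
    weight_suffixs = ["bin", "safetensors"]
    variant = "fp16"
    transformers_index_format = r"\d{5}-of-\d{5}"

    variant_file_re = re.compile(
        rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
    )

    assert variant_file_re.match("diffusion_pytorch_model.fp16.bin")
    assert variant_file_re.match("model.fp16-00001-of-00002.safetensors")
    assert variant_file_re.match("model.bin") is None  # non-variant file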
@@ -540,11 +541,9 @@ class DiffusionPipeline(ConfigMixin):
             variant (`str`, *optional*):
                 If specified, weights are saved in the format pytorch_model.<variant>.bin.
         """
-        self.save_config(save_directory)
-
         model_index_dict = dict(self.config)
-        model_index_dict.pop("_class_name")
-        model_index_dict.pop("_diffusers_version")
+        model_index_dict.pop("_class_name", None)
+        model_index_dict.pop("_diffusers_version", None)
         model_index_dict.pop("_module", None)

         expected_modules, optional_kwargs = self._get_signature_keys(self)
@@ -557,7 +556,6 @@ class DiffusionPipeline(ConfigMixin):
             return True

         model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
-
         for pipeline_component_name in model_index_dict.keys():
             sub_model = getattr(self, pipeline_component_name)
             model_cls = sub_model.__class__
@@ -571,7 +569,13 @@ class DiffusionPipeline(ConfigMixin):
             save_method_name = None
             # search for the model's base class in LOADABLE_CLASSES
             for library_name, library_classes in LOADABLE_CLASSES.items():
-                library = importlib.import_module(library_name)
+                if library_name in sys.modules:
+                    library = importlib.import_module(library_name)
+                else:
+                    logger.info(
+                        f"{library_name} is not installed. Cannot save {pipeline_component_name} as {library_classes} from {library_name}"
+                    )
+
                 for base_class, save_load_methods in library_classes.items():
                     class_candidate = getattr(library, base_class, None)
                     if class_candidate is not None and issubclass(model_cls, class_candidate):
@@ -581,6 +585,12 @@ class DiffusionPipeline(ConfigMixin):
                 if save_method_name is not None:
                     break

+            if save_method_name is None:
+                logger.warn(f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved.")
+                # make sure that unsaveable components are not tried to be loaded afterward
+                self.register_to_config(**{pipeline_component_name: (None, None)})
+                continue
+
             save_method = getattr(sub_model, save_method_name)

             # Call the save method with the argument safe_serialization only if it's supported
@@ -596,6 +606,9 @@ class DiffusionPipeline(ConfigMixin):

             save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)

+        # finally save the config
+        self.save_config(save_directory)
+
     def to(
         self,
         torch_device: Optional[Union[str, torch.device]] = None,
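Note: `save_pretrained` now writes `model_index.json` after the component folders rather than before, and any component without a known save method is registered as `(None, None)` so a later `from_pretrained` skips it instead of failing. A minimal round-trip sketch of the unchanged public API:

    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    # Component folders are written first; model_index.json is saved last, so an
    # interrupted save no longer leaves a config pointing at missing folders.
    pipe.save_pretrained("./sd15-local")

    # Components registered as (None, None) in the config are skipped on reload.
    pipe2 = StableDiffusionPipeline.from_pretrained("./sd15-local")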
@@ -610,7 +623,9 @@ class DiffusionPipeline(ConfigMixin):
             if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
                 return False

-            return hasattr(module, "_hf_hook") and not isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
+            return hasattr(module, "_hf_hook") and not isinstance(
+                module._hf_hook, (accelerate.hooks.CpuOffload, accelerate.hooks.AlignDevicesHook)
+            )

         def module_is_offloaded(module):
             if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
@@ -640,7 +655,20 @@ class DiffusionPipeline(ConfigMixin):

         is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
         for module in modules:
-            module.to(torch_device, torch_dtype)
+            is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit
+
+            if is_loaded_in_8bit and torch_dtype is not None:
+                logger.warning(
+                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {torch_dtype} is not yet supported. Module is still in 8bit precision."
+                )
+
+            if is_loaded_in_8bit and torch_device is not None:
+                logger.warning(
+                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}."
+                )
+            else:
+                module.to(torch_device, torch_dtype)
+
             if (
                 module.dtype == torch.float16
                 and str(torch_device) in ["cpu"]
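Note: `.to()` now warns and skips dtype/device moves for modules loaded in 8-bit (via bitsandbytes through transformers) instead of letting the call fail inside the module. A hedged sketch, assuming the text encoder was loaded with transformers' `load_in_8bit=True` (illustrative setup; requires bitsandbytes and accelerate):

    import torch
    from transformers import CLIPTextModel
    from diffusers import StableDiffusionPipeline

    # Illustrative: load only the text encoder in 8-bit precision.
    text_encoder = CLIPTextModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="text_encoder", load_in_8bit=True, device_map="auto"
    )
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", text_encoder=text_encoder
    )

    # The 8-bit text encoder is left untouched and a warning is logged;
    # the remaining modules are moved as usual.
    pipe.to("cuda", torch.float16)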
@@ -874,6 +902,9 @@ class DiffusionPipeline(ConfigMixin):

         config_dict = cls.load_config(cached_folder)

+        # pop out "_ignore_files" as it is only needed for download
+        config_dict.pop("_ignore_files", None)
+
         # 2. Define which model components should load variants
         # We retrieve the information by matching whether variant
         # model checkpoints exist in the subfolders
@@ -1045,7 +1076,7 @@ class DiffusionPipeline(ConfigMixin):
         return_cached_folder = kwargs.pop("return_cached_folder", False)
         if return_cached_folder:
             message = f"Passing `return_cached_folder=True` is deprecated and will be removed in `diffusers=0.17.0`. Please do the following instead: \n 1. Load the cached_folder via `cached_folder={cls}.download({pretrained_model_name_or_path})`. \n 2. Load the pipeline by loading from the cached folder: `pipeline={cls}.from_pretrained(cached_folder)`."
-            deprecate("return_cached_folder", "0.17.0", message, take_from=kwargs)
+            deprecate("return_cached_folder", "0.17.0", message)
             return model, cached_folder

         return model
@@ -1191,12 +1222,19 @@ class DiffusionPipeline(ConfigMixin):
             )

         config_dict = cls._dict_from_json_file(config_file)
+
+        ignore_filenames = config_dict.pop("_ignore_files", [])
+
         # retrieve all folder_names that contain relevant files
         folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]

         filenames = {sibling.rfilename for sibling in info.siblings}
         model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)

+        # remove ignored filenames
+        model_filenames = set(model_filenames) - set(ignore_filenames)
+        variant_filenames = set(variant_filenames) - set(ignore_filenames)
+
         # if the whole pipeline is cached we don't have to ping the Hub
         if revision in DEPRECATED_REVISION_ARGS and version.parse(
             version.parse(__version__).base_version
@@ -1357,16 +1395,7 @@ class DiffusionPipeline(ConfigMixin):
         """
         Convert a numpy image or a batch of images to a PIL image.
         """
-        if images.ndim == 3:
-            images = images[None, ...]
-        images = (images * 255).round().astype("uint8")
-        if images.shape[-1] == 1:
-            # special case for grayscale (single channel) images
-            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
-        else:
-            pil_images = [Image.fromarray(image) for image in images]
-
-        return pil_images
+        return numpy_to_pil(images)

     def progress_bar(self, iterable=None, total=None):
         if not hasattr(self, "_progress_bar_config"):
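Note: the conversion logic moved to the shared `numpy_to_pil` helper in `diffusers.utils` (see the new `diffusers/utils/pil_utils.py` in the file list); the pipeline method is now a thin wrapper. A quick sketch of the helper's contract:

    import numpy as np
    from diffusers.utils import numpy_to_pil

    # One HWC float image in [0, 1]; a batch axis is added automatically.
    image = np.random.rand(64, 64, 3).astype("float32")
    pils = numpy_to_pil(image)
    print(len(pils), pils[0].size)  # 1 (64, 64)

    # Single-channel arrays come back as mode "L" (grayscale) PIL images.
    gray = np.random.rand(2, 64, 64, 1).astype("float32")
    print(numpy_to_pil(gray)[0].mode)  # "L"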
diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

@@ -31,33 +31,30 @@ from transformers import (
     CLIPVisionModelWithProjection,
 )

-from diffusers import (
+from ...models import (
     AutoencoderKL,
     ControlNetModel,
+    PriorTransformer,
+    UNet2DConditionModel,
+)
+from ...schedulers import (
     DDIMScheduler,
     DDPMScheduler,
     DPMSolverMultistepScheduler,
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
     HeunDiscreteScheduler,
-    LDMTextToImagePipeline,
     LMSDiscreteScheduler,
     PNDMScheduler,
-    PriorTransformer,
-    StableDiffusionControlNetPipeline,
-    StableDiffusionPipeline,
-    StableUnCLIPImg2ImgPipeline,
-    StableUnCLIPPipeline,
     UnCLIPScheduler,
-    UNet2DConditionModel,
 )
-from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
-from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
-from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
-from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
-
 from ...utils import is_omegaconf_available, is_safetensors_available, logging
 from ...utils.import_utils import BACKENDS_MAPPING
+from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
+from ..paint_by_example import PaintByExampleImageEncoder
+from ..pipeline_utils import DiffusionPipeline
+from .safety_checker import StableDiffusionSafetyChecker
+from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -990,7 +987,8 @@ def download_from_original_stable_diffusion_ckpt(
     clip_stats_path: Optional[str] = None,
     controlnet: Optional[bool] = None,
     load_safety_checker: bool = True,
-) -> StableDiffusionPipeline:
+    pipeline_class: DiffusionPipeline = None,
+) -> DiffusionPipeline:
     """
     Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
     config file.
@@ -1018,6 +1016,8 @@ def download_from_original_stable_diffusion_ckpt(
         model_type (`str`, *optional*, defaults to `None`):
             The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder",
             "FrozenCLIPEmbedder", "PaintByExample"]`.
+        is_img2img (`bool`, *optional*, defaults to `False`):
+            Whether the model should be loaded as an img2img pipeline.
         extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
             checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
             `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
@@ -1026,12 +1026,29 @@ def download_from_original_stable_diffusion_ckpt(
             Whether the attention computation should always be upcasted. This is necessary when running stable
             diffusion 2.1.
         device (`str`, *optional*, defaults to `None`):
-            The device to use. Pass `None` to determine automatically. :param from_safetensors: If `checkpoint_path` is
-            in `safetensors` format, load checkpoint with safetensors instead of PyTorch. :return: A
-            StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
+            The device to use. Pass `None` to determine automatically.
+        from_safetensors (`str`, *optional*, defaults to `False`):
+            If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
         load_safety_checker (`bool`, *optional*, defaults to `True`):
             Whether to load the safety checker or not. Defaults to `True`.
+        pipeline_class (`str`, *optional*, defaults to `None`):
+            The pipeline class to use. Pass `None` to determine automatically.
+        return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
     """
+
+    # import pipelines here to avoid circular import error when using from_ckpt method
+    from diffusers import (
+        LDMTextToImagePipeline,
+        PaintByExamplePipeline,
+        StableDiffusionControlNetPipeline,
+        StableDiffusionPipeline,
+        StableUnCLIPImg2ImgPipeline,
+        StableUnCLIPPipeline,
+    )
+
+    if pipeline_class is None:
+        pipeline_class = StableDiffusionPipeline
+
     if prediction_type == "v-prediction":
         prediction_type = "v_prediction"

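Note: the converter no longer hard-codes `StableDiffusionPipeline`: the heavy pipeline imports are deferred into the function body to break the circular import with `loaders.py`, and `pipeline_class` lets callers request another compatible class. A hedged sketch with an illustrative local checkpoint:

    from diffusers import StableDiffusionImg2ImgPipeline
    from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
        download_from_original_stable_diffusion_ckpt,
    )

    # Illustrative local files; the converter builds the requested class
    # instead of always returning a StableDiffusionPipeline.
    pipe = download_from_original_stable_diffusion_ckpt(
        checkpoint_path="./v1-5-pruned-emaonly.ckpt",
        original_config_file="./v1-inference.yaml",
        pipeline_class=StableDiffusionImg2ImgPipeline,
    )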
@@ -1193,7 +1210,7 @@ def download_from_original_stable_diffusion_ckpt(
                 requires_safety_checker=False,
             )
         else:
-            pipe = StableDiffusionPipeline(
+            pipe = pipeline_class(
                 vae=vae,
                 text_encoder=text_model,
                 tokenizer=tokenizer,
@@ -1293,7 +1310,7 @@ def download_from_original_stable_diffusion_ckpt(
             feature_extractor=feature_extractor,
         )
     else:
-        pipe = StableDiffusionPipeline(
+        pipe = pipeline_class(
             vae=vae,
             text_encoder=text_model,
             tokenizer=tokenizer,
@@ -1320,7 +1337,7 @@ def download_controlnet_from_original_ckpt(
     upcast_attention: Optional[bool] = None,
     device: str = None,
     from_safetensors: bool = False,
-) -> StableDiffusionPipeline:
+) -> DiffusionPipeline:
     if not is_omegaconf_available():
         raise ValueError(BACKENDS_MAPPING["omegaconf"][1])

diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py

@@ -83,7 +83,7 @@ EXAMPLE_DOC_STRING = """
        ...     "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32
        ... )
        >>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
-       ...     "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.float32
+       ...     "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32
        ... )
        >>> params["controlnet"] = controlnet_params

diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py

@@ -56,7 +56,18 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
         scheduler: Any,
         max_noise_level: int = 350,
     ):
-        super().__init__(vae, text_encoder, tokenizer, unet, low_res_scheduler, scheduler, max_noise_level)
+        super().__init__(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            low_res_scheduler=low_res_scheduler,
+            scheduler=scheduler,
+            safety_checker=None,
+            feature_extractor=None,
+            watermarker=None,
+            max_noise_level=max_noise_level,
+        )

     def __call__(
         self,
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

@@ -20,7 +20,7 @@ from packaging import version
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -53,13 +53,21 @@ EXAMPLE_DOC_STRING = """
 """


-class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.

     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
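Note: `FromCkptMixin` comes from the expanded `diffusers/loaders.py` (+262 lines in this release) and exposes the converter above as a classmethod. A minimal sketch with an illustrative local file:

    from diffusers import StableDiffusionPipeline

    # Single-file CompVis-style checkpoint; the path is illustrative. Internally
    # this routes through download_from_original_stable_diffusion_ckpt with
    # pipeline_class set to the calling pipeline class.
    pipe = StableDiffusionPipeline.from_ckpt("./v1-5-pruned-emaonly.safetensors")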
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py

@@ -14,7 +14,7 @@

 import inspect
 import math
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -76,7 +76,7 @@ class AttentionStore:

     def __call__(self, attn, is_cross: bool, place_in_unet: str):
         if self.cur_att_layer >= 0 and is_cross:
-            if attn.shape[1] == self.attn_res**2:
+            if attn.shape[1] == np.prod(self.attn_res):
                 self.step_store[place_in_unet].append(attn)

         self.cur_att_layer += 1
@@ -98,7 +98,7 @@ class AttentionStore:
         attention_maps = self.get_average_attention()
         for location in from_where:
             for item in attention_maps[location]:
-                cross_maps = item.reshape(-1, self.attn_res, self.attn_res, item.shape[-1])
+                cross_maps = item.reshape(-1, self.attn_res[0], self.attn_res[1], item.shape[-1])
                 out.append(cross_maps)
         out = torch.cat(out, dim=0)
         out = out.sum(0) / out.shape[0]
@@ -109,7 +109,7 @@ class AttentionStore:
         self.step_store = self.get_empty_store()
         self.attention_store = {}

-    def __init__(self, attn_res=16):
+    def __init__(self, attn_res):
        """
        Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion
        process
@@ -724,7 +724,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
         max_iter_to_alter: int = 25,
         thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8},
         scale_factor: int = 20,
-        attn_res: int = 16,
+        attn_res: Optional[Tuple[int]] = (16, 16),
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -796,8 +796,8 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
                 Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in.
             scale_factor (`int`, *optional*, default to 20):
                 Scale factor that controls the step size of each Attend and Excite update.
-            attn_res (`int`, *optional*, default to 16):
-                The resolution of most semantic attention map.
+            attn_res (`tuple`, *optional*, default computed from width and height):
+                The 2D resolution of the semantic attention map.

         Examples:

@@ -870,7 +870,9 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

-        self.attention_store = AttentionStore(attn_res=attn_res)
+        if attn_res is None:
+            attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))
+        self.attention_store = AttentionStore(attn_res)
         self.register_attention_control()

         # default config for step size from original repo
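Note: `attn_res` is now a 2D tuple, so Attend-and-Excite works for non-square outputs; passing `None` derives the resolution from the requested size (the stored cross-attention maps live at 1/32 of the pixel resolution). A hedged usage sketch:

    import torch
    from diffusers import StableDiffusionAttendAndExcitePipeline

    pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
    ).to("cuda")

    # For a 768x512 request, attn_res=None resolves to
    # (ceil(768 / 32), ceil(512 / 32)) = (24, 16).
    image = pipe(
        "a cat and a frog",
        token_indices=[2, 5],
        width=768,
        height=512,
        attn_res=None,
    ).images[0]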
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py

@@ -118,6 +118,7 @@ class MultiControlNetModel(ModelMixin):
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple]:
         for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
@@ -131,6 +132,7 @@ class MultiControlNetModel(ModelMixin):
                 timestep_cond,
                 attention_mask,
                 cross_attention_kwargs,
+                guess_mode,
                 return_dict,
             )

@@ -154,6 +156,9 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
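Note: the docstring now advertises textual-inversion loading, inherited from `TextualInversionLoaderMixin`. A hedged sketch with a public example concept:

    import torch
    from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

    controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
    )

    # Textual-inversion embeddings can now be loaded directly on this pipeline;
    # the repo id below is a public example concept.
    pipe.load_textual_inversion("sd-concepts-library/cat-toy")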
@@ -244,6 +249,24 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
         """
         self.vae.disable_slicing()

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+        several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+        """
+        self.vae.enable_tiling()
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
@@ -627,7 +650,16 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
         )

     def prepare_image(
-        self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance
+        self,
+        image,
+        width,
+        height,
+        batch_size,
+        num_images_per_prompt,
+        device,
+        dtype,
+        do_classifier_free_guidance=False,
+        guess_mode=False,
     ):
         if not isinstance(image, torch.Tensor):
             if isinstance(image, PIL.Image.Image):
@@ -664,7 +696,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade

         image = image.to(device=device, dtype=dtype)

-        if do_classifier_free_guidance:
+        if do_classifier_free_guidance and not guess_mode:
             image = torch.cat([image] * 2)

         return image
@@ -747,6 +779,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+        guess_mode: bool = False,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -819,6 +852,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                 The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
                 to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
                 corresponding scale as a list.
+            guess_mode (`bool`, *optional*, defaults to `False`):
+                In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
+                you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
+
         Examples:

         Returns:
@@ -883,6 +920,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                 device=device,
                 dtype=self.controlnet.dtype,
                 do_classifier_free_guidance=do_classifier_free_guidance,
+                guess_mode=guess_mode,
             )
         elif isinstance(self.controlnet, MultiControlNetModel):
             images = []
@@ -897,6 +935,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                     device=device,
                     dtype=self.controlnet.dtype,
                     do_classifier_free_guidance=do_classifier_free_guidance,
+                    guess_mode=guess_mode,
                 )

                 images.append(image_)
@@ -934,15 +973,31 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                 # controlnet(s) inference
+                if guess_mode and do_classifier_free_guidance:
+                    # Infer ControlNet only for the conditional batch.
+                    controlnet_latent_model_input = latents
+                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+                else:
+                    controlnet_latent_model_input = latent_model_input
+                    controlnet_prompt_embeds = prompt_embeds
+
                 down_block_res_samples, mid_block_res_sample = self.controlnet(
-                    latent_model_input,
+                    controlnet_latent_model_input,
                     t,
-                    encoder_hidden_states=prompt_embeds,
+                    encoder_hidden_states=controlnet_prompt_embeds,
                     controlnet_cond=image,
                     conditioning_scale=controlnet_conditioning_scale,
+                    guess_mode=guess_mode,
                     return_dict=False,
                 )

+                if guess_mode and do_classifier_free_guidance:
+                    # Infered ControlNet only for the conditional batch.
+                    # To apply the output of ControlNet to both the unconditional and conditional batches,
+                    # add 0 to the unconditional batch to keep it unchanged.
+                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
                 # predict the noise residual
                 noise_pred = self.unet(
                     latent_model_input,
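Note: in guess mode the ControlNet runs only on the conditional half of the classifier-free-guidance batch, and its residuals are zero-padded for the unconditional half, so the control image steers only the conditioned prediction. A hedged end-to-end sketch (the edge-map path is illustrative):

    import torch
    from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
    from diffusers.utils import load_image

    controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
    ).to("cuda")

    canny = load_image("./canny_edges.png")  # illustrative local edge map

    # Even with an empty prompt, guess mode lets the ControlNet infer content
    # from the conditioning image; the docstring recommends guidance_scale 3-5.
    image = pipe("", image=canny, guess_mode=True, guidance_scale=3.0).images[0]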
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py

@@ -23,7 +23,7 @@ from packaging import version
 from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation

 from ...configuration_utils import FrozenDict
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
@@ -55,13 +55,20 @@ def preprocess(image):
     return image


-class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
     r"""
     Pipeline for text-guided image to image generation using Stable Diffusion.

     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

@@ -23,7 +23,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
-from ...loaders import TextualInversionLoaderMixin
+from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -92,13 +92,21 @@ def preprocess(image):
     return image


-class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
     r"""
     Pipeline for text-guided image to image generation using Stable Diffusion.

     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    In addition the pipeline inherits the following loading methods:
+        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+        - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
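Note: with `LoraLoaderMixin` mixed into the img2img and depth2img pipelines above, LoRA weights load the same way across the Stable Diffusion pipeline family. A hedged sketch; the LoRA repo id is illustrative:

    import torch
    from diffusers import StableDiffusionImg2ImgPipeline

    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # load_lora_weights comes from LoraLoaderMixin; it patches the UNet (and,
    # when present, the text encoder) attention layers. Repo id is illustrative.
    pipe.load_lora_weights("sayakpaul/sd-model-finetuned-lora-t4")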