diffusers 0.15.1__py3-none-any.whl → 0.16.1__py3-none-any.whl

Files changed (57)
  1. diffusers/__init__.py +7 -2
  2. diffusers/configuration_utils.py +4 -0
  3. diffusers/loaders.py +262 -12
  4. diffusers/models/attention.py +31 -12
  5. diffusers/models/attention_processor.py +189 -0
  6. diffusers/models/controlnet.py +9 -2
  7. diffusers/models/embeddings.py +66 -0
  8. diffusers/models/modeling_pytorch_flax_utils.py +6 -0
  9. diffusers/models/modeling_utils.py +5 -2
  10. diffusers/models/transformer_2d.py +1 -1
  11. diffusers/models/unet_2d_condition.py +45 -6
  12. diffusers/models/vae.py +3 -0
  13. diffusers/pipelines/__init__.py +8 -0
  14. diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +25 -10
  15. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +8 -0
  16. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +8 -0
  17. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
  18. diffusers/pipelines/deepfloyd_if/__init__.py +54 -0
  19. diffusers/pipelines/deepfloyd_if/pipeline_if.py +854 -0
  20. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +979 -0
  21. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1097 -0
  22. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1098 -0
  23. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1208 -0
  24. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +947 -0
  25. diffusers/pipelines/deepfloyd_if/safety_checker.py +59 -0
  26. diffusers/pipelines/deepfloyd_if/timesteps.py +579 -0
  27. diffusers/pipelines/deepfloyd_if/watermark.py +46 -0
  28. diffusers/pipelines/pipeline_utils.py +54 -25
  29. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +37 -20
  30. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +1 -1
  31. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +12 -1
  32. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -2
  33. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -8
  34. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +59 -4
  35. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +9 -2
  36. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -2
  37. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +9 -2
  38. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +22 -12
  39. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +9 -2
  40. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +34 -30
  41. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +93 -10
  42. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +45 -6
  43. diffusers/schedulers/scheduling_ddpm.py +63 -16
  44. diffusers/schedulers/scheduling_heun_discrete.py +51 -1
  45. diffusers/utils/__init__.py +4 -1
  46. diffusers/utils/dummy_torch_and_transformers_objects.py +80 -5
  47. diffusers/utils/dynamic_modules_utils.py +1 -1
  48. diffusers/utils/hub_utils.py +4 -1
  49. diffusers/utils/import_utils.py +41 -0
  50. diffusers/utils/pil_utils.py +24 -0
  51. diffusers/utils/testing_utils.py +10 -0
  52. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/METADATA +1 -1
  53. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/RECORD +57 -47
  54. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/LICENSE +0 -0
  55. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/WHEEL +0 -0
  56. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/entry_points.txt +0 -0
  57. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/top_level.txt +0 -0
diffusers/pipelines/pipeline_utils.py
@@ -19,6 +19,7 @@ import importlib
  import inspect
  import os
  import re
+ import sys
  import warnings
  from dataclasses import dataclass
  from pathlib import Path
@@ -29,7 +30,6 @@ import PIL
  import torch
  from huggingface_hub import hf_hub_download, model_info, snapshot_download
  from packaging import version
- from PIL import Image
  from tqdm.auto import tqdm

  import diffusers
@@ -55,6 +55,7 @@ from ..utils import (
  is_torch_version,
  is_transformers_available,
  logging,
+ numpy_to_pil,
  )


@@ -200,24 +201,24 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
  # .bin, .safetensors, ...
  weight_suffixs = [w.split(".")[-1] for w in weight_names]
  # -00001-of-00002
- transformers_index_format = "\d{5}-of-\d{5}"
+ transformers_index_format = r"\d{5}-of-\d{5}"

  if variant is not None:
  # `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetenstors`
  variant_file_re = re.compile(
- f"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
+ rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
  )
  # `text_encoder/pytorch_model.bin.index.fp16.json`
  variant_index_re = re.compile(
- f"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
+ rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
  )

  # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetenstors`
  non_variant_file_re = re.compile(
- f"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
+ rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
  )
  # `text_encoder/pytorch_model.bin.index.json`
- non_variant_index_re = re.compile(f"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
+ non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")

  if variant is not None:
  variant_weights = {f for f in filenames if variant_file_re.match(f.split("/")[-1]) is not None}
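The switch to raw strings above only silences Python's invalid-escape-sequence warnings; the patterns themselves are unchanged. A minimal sketch of what the variant pattern matches, assuming a fixed prefix/suffix list and the `fp16` variant (the real function derives these lists from the pipeline's weight files):

    import re

    # Assumed, simplified inputs for illustration only.
    weight_prefixes = ["pytorch_model", "diffusion_pytorch_model", "model"]
    weight_suffixs = ["bin", "safetensors"]
    transformers_index_format = r"\d{5}-of-\d{5}"
    variant = "fp16"

    variant_file_re = re.compile(
        rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
    )

    # Matches single-file and sharded variant weights, but not the non-variant file.
    print(bool(variant_file_re.match("diffusion_pytorch_model.fp16.bin")))       # True
    print(bool(variant_file_re.match("model.fp16-00001-of-00002.safetensors")))  # True
    print(bool(variant_file_re.match("diffusion_pytorch_model.bin")))            # False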
@@ -540,11 +541,9 @@ class DiffusionPipeline(ConfigMixin):
  variant (`str`, *optional*):
  If specified, weights are saved in the format pytorch_model.<variant>.bin.
  """
- self.save_config(save_directory)
-
  model_index_dict = dict(self.config)
- model_index_dict.pop("_class_name")
- model_index_dict.pop("_diffusers_version")
+ model_index_dict.pop("_class_name", None)
+ model_index_dict.pop("_diffusers_version", None)
  model_index_dict.pop("_module", None)

  expected_modules, optional_kwargs = self._get_signature_keys(self)
@@ -557,7 +556,6 @@ class DiffusionPipeline(ConfigMixin):
  return True

  model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
-
  for pipeline_component_name in model_index_dict.keys():
  sub_model = getattr(self, pipeline_component_name)
  model_cls = sub_model.__class__
@@ -571,7 +569,13 @@ class DiffusionPipeline(ConfigMixin):
  save_method_name = None
  # search for the model's base class in LOADABLE_CLASSES
  for library_name, library_classes in LOADABLE_CLASSES.items():
- library = importlib.import_module(library_name)
+ if library_name in sys.modules:
+ library = importlib.import_module(library_name)
+ else:
+ logger.info(
+ f"{library_name} is not installed. Cannot save {pipeline_component_name} as {library_classes} from {library_name}"
+ )
+
  for base_class, save_load_methods in library_classes.items():
  class_candidate = getattr(library, base_class, None)
  if class_candidate is not None and issubclass(model_cls, class_candidate):
@@ -581,6 +585,12 @@ class DiffusionPipeline(ConfigMixin):
  if save_method_name is not None:
  break

+ if save_method_name is None:
+ logger.warn(f"self.{pipeline_component_name}={sub_model} of type {type(sub_model)} cannot be saved.")
+ # make sure that unsaveable components are not tried to be loaded afterward
+ self.register_to_config(**{pipeline_component_name: (None, None)})
+ continue
+
  save_method = getattr(sub_model, save_method_name)

  # Call the save method with the argument safe_serialization only if it's supported
@@ -596,6 +606,9 @@ class DiffusionPipeline(ConfigMixin):

  save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)

+ # finally save the config
+ self.save_config(save_directory)
+
  def to(
  self,
  torch_device: Optional[Union[str, torch.device]] = None,
@@ -610,7 +623,9 @@ class DiffusionPipeline(ConfigMixin):
  if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
  return False

- return hasattr(module, "_hf_hook") and not isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
+ return hasattr(module, "_hf_hook") and not isinstance(
+ module._hf_hook, (accelerate.hooks.CpuOffload, accelerate.hooks.AlignDevicesHook)
+ )

  def module_is_offloaded(module):
  if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
@@ -640,7 +655,20 @@ class DiffusionPipeline(ConfigMixin):

  is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
  for module in modules:
- module.to(torch_device, torch_dtype)
+ is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit
+
+ if is_loaded_in_8bit and torch_dtype is not None:
+ logger.warning(
+ f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {torch_dtype} is not yet supported. Module is still in 8bit precision."
+ )
+
+ if is_loaded_in_8bit and torch_device is not None:
+ logger.warning(
+ f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}."
+ )
+ else:
+ module.to(torch_device, torch_dtype)
+
  if (
  module.dtype == torch.float16
  and str(torch_device) in ["cpu"]
@@ -874,6 +902,9 @@ class DiffusionPipeline(ConfigMixin):

  config_dict = cls.load_config(cached_folder)

+ # pop out "_ignore_files" as it is only needed for download
+ config_dict.pop("_ignore_files", None)
+
  # 2. Define which model components should load variants
  # We retrieve the information by matching whether variant
  # model checkpoints exist in the subfolders
@@ -1045,7 +1076,7 @@ class DiffusionPipeline(ConfigMixin):
  return_cached_folder = kwargs.pop("return_cached_folder", False)
  if return_cached_folder:
  message = f"Passing `return_cached_folder=True` is deprecated and will be removed in `diffusers=0.17.0`. Please do the following instead: \n 1. Load the cached_folder via `cached_folder={cls}.download({pretrained_model_name_or_path})`. \n 2. Load the pipeline by loading from the cached folder: `pipeline={cls}.from_pretrained(cached_folder)`."
- deprecate("return_cached_folder", "0.17.0", message, take_from=kwargs)
+ deprecate("return_cached_folder", "0.17.0", message)
  return model, cached_folder

  return model
@@ -1191,12 +1222,19 @@ class DiffusionPipeline(ConfigMixin):
  )

  config_dict = cls._dict_from_json_file(config_file)
+
+ ignore_filenames = config_dict.pop("_ignore_files", [])
+
  # retrieve all folder_names that contain relevant files
  folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]

  filenames = {sibling.rfilename for sibling in info.siblings}
  model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)

+ # remove ignored filenames
+ model_filenames = set(model_filenames) - set(ignore_filenames)
+ variant_filenames = set(variant_filenames) - set(ignore_filenames)
+
  # if the whole pipeline is cached we don't have to ping the Hub
  if revision in DEPRECATED_REVISION_ARGS and version.parse(
  version.parse(__version__).base_version
@@ -1357,16 +1395,7 @@ class DiffusionPipeline(ConfigMixin):
  """
  Convert a numpy image or a batch of images to a PIL image.
  """
- if images.ndim == 3:
- images = images[None, ...]
- images = (images * 255).round().astype("uint8")
- if images.shape[-1] == 1:
- # special case for grayscale (single channel) images
- pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
- else:
- pil_images = [Image.fromarray(image) for image in images]
-
- return pil_images
+ return numpy_to_pil(images)

  def progress_bar(self, iterable=None, total=None):
  if not hasattr(self, "_progress_bar_config"):
diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -31,33 +31,30 @@ from transformers import (
  CLIPVisionModelWithProjection,
  )

- from diffusers import (
+ from ...models import (
  AutoencoderKL,
  ControlNetModel,
+ PriorTransformer,
+ UNet2DConditionModel,
+ )
+ from ...schedulers import (
  DDIMScheduler,
  DDPMScheduler,
  DPMSolverMultistepScheduler,
  EulerAncestralDiscreteScheduler,
  EulerDiscreteScheduler,
  HeunDiscreteScheduler,
- LDMTextToImagePipeline,
  LMSDiscreteScheduler,
  PNDMScheduler,
- PriorTransformer,
- StableDiffusionControlNetPipeline,
- StableDiffusionPipeline,
- StableUnCLIPImg2ImgPipeline,
- StableUnCLIPPipeline,
  UnCLIPScheduler,
- UNet2DConditionModel,
  )
- from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
- from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder, PaintByExamplePipeline
- from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
- from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
-
  from ...utils import is_omegaconf_available, is_safetensors_available, logging
  from ...utils.import_utils import BACKENDS_MAPPING
+ from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
+ from ..paint_by_example import PaintByExampleImageEncoder
+ from ..pipeline_utils import DiffusionPipeline
+ from .safety_checker import StableDiffusionSafetyChecker
+ from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer


  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -990,7 +987,8 @@ def download_from_original_stable_diffusion_ckpt(
  clip_stats_path: Optional[str] = None,
  controlnet: Optional[bool] = None,
  load_safety_checker: bool = True,
- ) -> StableDiffusionPipeline:
+ pipeline_class: DiffusionPipeline = None,
+ ) -> DiffusionPipeline:
  """
  Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
  config file.
@@ -1018,6 +1016,8 @@ def download_from_original_stable_diffusion_ckpt(
  model_type (`str`, *optional*, defaults to `None`):
  The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder",
  "FrozenCLIPEmbedder", "PaintByExample"]`.
+ is_img2img (`bool`, *optional*, defaults to `False`):
+ Whether the model should be loaded as an img2img pipeline.
  extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
  checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
  `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
@@ -1026,12 +1026,29 @@ def download_from_original_stable_diffusion_ckpt(
  Whether the attention computation should always be upcasted. This is necessary when running stable
  diffusion 2.1.
  device (`str`, *optional*, defaults to `None`):
- The device to use. Pass `None` to determine automatically. :param from_safetensors: If `checkpoint_path` is
- in `safetensors` format, load checkpoint with safetensors instead of PyTorch. :return: A
- StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
+ The device to use. Pass `None` to determine automatically.
+ from_safetensors (`str`, *optional*, defaults to `False`):
+ If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
  load_safety_checker (`bool`, *optional*, defaults to `True`):
  Whether to load the safety checker or not. Defaults to `True`.
+ pipeline_class (`str`, *optional*, defaults to `None`):
+ The pipeline class to use. Pass `None` to determine automatically.
+ return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
  """
+
+ # import pipelines here to avoid circular import error when using from_ckpt method
+ from diffusers import (
+ LDMTextToImagePipeline,
+ PaintByExamplePipeline,
+ StableDiffusionControlNetPipeline,
+ StableDiffusionPipeline,
+ StableUnCLIPImg2ImgPipeline,
+ StableUnCLIPPipeline,
+ )
+
+ if pipeline_class is None:
+ pipeline_class = StableDiffusionPipeline
+
  if prediction_type == "v-prediction":
  prediction_type = "v_prediction"

@@ -1193,7 +1210,7 @@ def download_from_original_stable_diffusion_ckpt(
  requires_safety_checker=False,
  )
  else:
- pipe = StableDiffusionPipeline(
+ pipe = pipeline_class(
  vae=vae,
  text_encoder=text_model,
  tokenizer=tokenizer,
@@ -1293,7 +1310,7 @@ def download_from_original_stable_diffusion_ckpt(
  feature_extractor=feature_extractor,
  )
  else:
- pipe = StableDiffusionPipeline(
+ pipe = pipeline_class(
  vae=vae,
  text_encoder=text_model,
  tokenizer=tokenizer,
@@ -1320,7 +1337,7 @@ def download_controlnet_from_original_ckpt(
  upcast_attention: Optional[bool] = None,
  device: str = None,
  from_safetensors: bool = False,
- ) -> StableDiffusionPipeline:
+ ) -> DiffusionPipeline:
  if not is_omegaconf_available():
  raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
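With the new `pipeline_class` argument, callers can ask the converter for a pipeline type other than `StableDiffusionPipeline`. A hedged usage sketch, with the checkpoint path as a placeholder:

    from diffusers import StableDiffusionImg2ImgPipeline
    from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
        download_from_original_stable_diffusion_ckpt,
    )

    # Placeholder checkpoint path; any CompVis-style .ckpt/.safetensors file works here.
    pipe = download_from_original_stable_diffusion_ckpt(
        "path/to/model.ckpt",
        pipeline_class=StableDiffusionImg2ImgPipeline,  # defaults to StableDiffusionPipeline when None
    )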
 
diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py
@@ -83,7 +83,7 @@ EXAMPLE_DOC_STRING = """
  ... "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32
  ... )
  >>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.float32
+ ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32
  ... )
  >>> params["controlnet"] = controlnet_params

diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
@@ -56,7 +56,18 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
  scheduler: Any,
  max_noise_level: int = 350,
  ):
- super().__init__(vae, text_encoder, tokenizer, unet, low_res_scheduler, scheduler, max_noise_level)
+ super().__init__(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ low_res_scheduler=low_res_scheduler,
+ scheduler=scheduler,
+ safety_checker=None,
+ feature_extractor=None,
+ watermarker=None,
+ max_noise_level=max_noise_level,
+ )

  def __call__(
  self,
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -20,7 +20,7 @@ from packaging import version
  from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

  from ...configuration_utils import FrozenDict
- from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+ from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
  from ...models import AutoencoderKL, UNet2DConditionModel
  from ...schedulers import KarrasDiffusionSchedulers
  from ...utils import (
@@ -53,13 +53,21 @@ EXAMPLE_DOC_STRING = """
  """


- class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
  r"""
  Pipeline for text-to-image generation using Stable Diffusion.

  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+ - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
  Args:
  vae ([`AutoencoderKL`]):
  Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
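The newly documented `from_ckpt` entry point (provided by `FromCkptMixin`) builds a pipeline directly from a single CompVis-style checkpoint. A minimal sketch, with the file path as a placeholder:

    import torch
    from diffusers import StableDiffusionPipeline

    # Placeholder path; a URL to a .ckpt/.safetensors file also works.
    pipe = StableDiffusionPipeline.from_ckpt(
        "path/to/model.safetensors",
        torch_dtype=torch.float16,
    )
    pipe = pipe.to("cuda")
    image = pipe("a photo of an astronaut riding a horse on mars").images[0]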
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
@@ -14,7 +14,7 @@

  import inspect
  import math
- from typing import Any, Callable, Dict, List, Optional, Union
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union

  import numpy as np
  import torch
@@ -76,7 +76,7 @@ class AttentionStore:

  def __call__(self, attn, is_cross: bool, place_in_unet: str):
  if self.cur_att_layer >= 0 and is_cross:
- if attn.shape[1] == self.attn_res**2:
+ if attn.shape[1] == np.prod(self.attn_res):
  self.step_store[place_in_unet].append(attn)

  self.cur_att_layer += 1
@@ -98,7 +98,7 @@ class AttentionStore:
  attention_maps = self.get_average_attention()
  for location in from_where:
  for item in attention_maps[location]:
- cross_maps = item.reshape(-1, self.attn_res, self.attn_res, item.shape[-1])
+ cross_maps = item.reshape(-1, self.attn_res[0], self.attn_res[1], item.shape[-1])
  out.append(cross_maps)
  out = torch.cat(out, dim=0)
  out = out.sum(0) / out.shape[0]
@@ -109,7 +109,7 @@ class AttentionStore:
  self.step_store = self.get_empty_store()
  self.attention_store = {}

- def __init__(self, attn_res=16):
+ def __init__(self, attn_res):
  """
  Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion
  process
@@ -724,7 +724,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
  max_iter_to_alter: int = 25,
  thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8},
  scale_factor: int = 20,
- attn_res: int = 16,
+ attn_res: Optional[Tuple[int]] = (16, 16),
  ):
  r"""
  Function invoked when calling the pipeline for generation.
@@ -796,8 +796,8 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
  Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in.
  scale_factor (`int`, *optional*, default to 20):
  Scale factor that controls the step size of each Attend and Excite update.
- attn_res (`int`, *optional*, default to 16):
- The resolution of most semantic attention map.
+ attn_res (`tuple`, *optional*, default computed from width and height):
+ The 2D resolution of the semantic attention map.

  Examples:

@@ -870,7 +870,9 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
  # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

- self.attention_store = AttentionStore(attn_res=attn_res)
+ if attn_res is None:
+ attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))
+ self.attention_store = AttentionStore(attn_res)
  self.register_attention_control()

  # default config for step size from original repo
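Because `attn_res` is now a 2D tuple (and is derived from `width` and `height` when `None` is passed), non-square generations get a matching non-square attention map. A small sketch of the default computation shown above, with an assumed generation size:

    import numpy as np

    # Placeholder generation size in pixels.
    width, height = 768, 512
    attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))
    print(attn_res)            # (24, 16)
    print(np.prod(attn_res))   # 384 attention positions, matching the attn.shape[1] check above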
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py
@@ -118,6 +118,7 @@ class MultiControlNetModel(ModelMixin):
  timestep_cond: Optional[torch.Tensor] = None,
  attention_mask: Optional[torch.Tensor] = None,
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guess_mode: bool = False,
  return_dict: bool = True,
  ) -> Union[ControlNetOutput, Tuple]:
  for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
@@ -131,6 +132,7 @@ class MultiControlNetModel(ModelMixin):
  timestep_cond,
  attention_mask,
  cross_attention_kwargs,
+ guess_mode,
  return_dict,
  )

@@ -154,6 +156,9 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+
  Args:
  vae ([`AutoencoderKL`]):
  Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
@@ -244,6 +249,24 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
  """
  self.vae.disable_slicing()

+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
  def enable_sequential_cpu_offload(self, gpu_id=0):
  r"""
  Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
@@ -627,7 +650,16 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
  )

  def prepare_image(
- self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
  ):
  if not isinstance(image, torch.Tensor):
  if isinstance(image, PIL.Image.Image):
@@ -664,7 +696,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade

  image = image.to(device=device, dtype=dtype)

- if do_classifier_free_guidance:
+ if do_classifier_free_guidance and not guess_mode:
  image = torch.cat([image] * 2)

  return image
@@ -747,6 +779,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
  callback_steps: int = 1,
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
  controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ guess_mode: bool = False,
  ):
  r"""
  Function invoked when calling the pipeline for generation.
@@ -819,6 +852,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
  The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
  to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
  corresponding scale as a list.
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
+ you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
+
  Examples:

  Returns:
@@ -883,6 +920,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
  device=device,
  dtype=self.controlnet.dtype,
  do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
  )
  elif isinstance(self.controlnet, MultiControlNetModel):
  images = []
@@ -897,6 +935,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
  device=device,
  dtype=self.controlnet.dtype,
  do_classifier_free_guidance=do_classifier_free_guidance,
+ guess_mode=guess_mode,
  )

  images.append(image_)
@@ -934,15 +973,31 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

  # controlnet(s) inference
+ if guess_mode and do_classifier_free_guidance:
+ # Infer ControlNet only for the conditional batch.
+ controlnet_latent_model_input = latents
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ else:
+ controlnet_latent_model_input = latent_model_input
+ controlnet_prompt_embeds = prompt_embeds
+
  down_block_res_samples, mid_block_res_sample = self.controlnet(
- latent_model_input,
+ controlnet_latent_model_input,
  t,
- encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states=controlnet_prompt_embeds,
  controlnet_cond=image,
  conditioning_scale=controlnet_conditioning_scale,
+ guess_mode=guess_mode,
  return_dict=False,
  )

+ if guess_mode and do_classifier_free_guidance:
+ # Infered ControlNet only for the conditional batch.
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
+ # add 0 to the unconditional batch to keep it unchanged.
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
  # predict the noise residual
  noise_pred = self.unet(
  latent_model_input,
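A hedged usage sketch for the new `guess_mode` flag; the model ids match the ones used elsewhere in these docstrings, and the control image path is a placeholder:

    import torch
    from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
    from diffusers.utils import load_image

    controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
    ).to("cuda")

    canny_image = load_image("path/to/canny_edges.png")  # placeholder pre-computed edge map
    image = pipe(
        "",                   # guess mode is meant to work even without a prompt
        image=canny_image,
        guess_mode=True,
        guidance_scale=3.0,   # the docstring above recommends 3.0-5.0 in guess mode
    ).images[0]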
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
@@ -23,7 +23,7 @@ from packaging import version
  from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation

  from ...configuration_utils import FrozenDict
- from ...loaders import TextualInversionLoaderMixin
+ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
  from ...models import AutoencoderKL, UNet2DConditionModel
  from ...schedulers import KarrasDiffusionSchedulers
  from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
@@ -55,13 +55,20 @@ def preprocess(image):
  return image


- class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
  r"""
  Pipeline for text-guided image to image generation using Stable Diffusion.

  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
  Args:
  vae ([`AutoencoderKL`]):
  Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -23,7 +23,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

  from ...configuration_utils import FrozenDict
  from ...image_processor import VaeImageProcessor
- from ...loaders import TextualInversionLoaderMixin
+ from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
  from ...models import AutoencoderKL, UNet2DConditionModel
  from ...schedulers import KarrasDiffusionSchedulers
  from ...utils import (
@@ -92,13 +92,21 @@ def preprocess(image):
  return image


- class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin):
  r"""
  Pipeline for text-guided image to image generation using Stable Diffusion.

  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+ - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+ - *Ckpt*: [`loaders.FromCkptMixin.from_ckpt`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
+
  Args:
  vae ([`AutoencoderKL`]):
  Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
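The same mixins apply to the img2img pipeline, so LoRA weights can be attached after loading. A minimal sketch, with the LoRA location as a placeholder:

    import torch
    from diffusers import StableDiffusionImg2ImgPipeline

    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # Placeholder: a local folder or Hub repo containing LoRA attention weights
    # saved via LoraLoaderMixin.save_lora_weights.
    pipe.load_lora_weights("path/to/lora_weights")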