diffusers 0.15.1__py3-none-any.whl → 0.16.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. diffusers/__init__.py +7 -2
  2. diffusers/configuration_utils.py +4 -0
  3. diffusers/loaders.py +262 -12
  4. diffusers/models/attention.py +31 -12
  5. diffusers/models/attention_processor.py +189 -0
  6. diffusers/models/controlnet.py +9 -2
  7. diffusers/models/embeddings.py +66 -0
  8. diffusers/models/modeling_pytorch_flax_utils.py +6 -0
  9. diffusers/models/modeling_utils.py +5 -2
  10. diffusers/models/transformer_2d.py +1 -1
  11. diffusers/models/unet_2d_condition.py +45 -6
  12. diffusers/models/vae.py +3 -0
  13. diffusers/pipelines/__init__.py +8 -0
  14. diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +25 -10
  15. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +8 -0
  16. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +8 -0
  17. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
  18. diffusers/pipelines/deepfloyd_if/__init__.py +54 -0
  19. diffusers/pipelines/deepfloyd_if/pipeline_if.py +854 -0
  20. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +979 -0
  21. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1097 -0
  22. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1098 -0
  23. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1208 -0
  24. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +947 -0
  25. diffusers/pipelines/deepfloyd_if/safety_checker.py +59 -0
  26. diffusers/pipelines/deepfloyd_if/timesteps.py +579 -0
  27. diffusers/pipelines/deepfloyd_if/watermark.py +46 -0
  28. diffusers/pipelines/pipeline_utils.py +54 -25
  29. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +37 -20
  30. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +1 -1
  31. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +12 -1
  32. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -2
  33. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -8
  34. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +59 -4
  35. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +9 -2
  36. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -2
  37. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +9 -2
  38. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +22 -12
  39. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +9 -2
  40. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +34 -30
  41. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +93 -10
  42. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +45 -6
  43. diffusers/schedulers/scheduling_ddpm.py +63 -16
  44. diffusers/schedulers/scheduling_heun_discrete.py +51 -1
  45. diffusers/utils/__init__.py +4 -1
  46. diffusers/utils/dummy_torch_and_transformers_objects.py +80 -5
  47. diffusers/utils/dynamic_modules_utils.py +1 -1
  48. diffusers/utils/hub_utils.py +4 -1
  49. diffusers/utils/import_utils.py +41 -0
  50. diffusers/utils/pil_utils.py +24 -0
  51. diffusers/utils/testing_utils.py +10 -0
  52. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/METADATA +1 -1
  53. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/RECORD +57 -47
  54. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/LICENSE +0 -0
  55. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/WHEEL +0 -0
  56. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/entry_points.txt +0 -0
  57. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/top_level.txt +0 -0
diffusers/__init__.py CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.15.1"
+__version__ = "0.16.1"
 
 from .configuration_utils import ConfigMixin
 from .utils import (
@@ -109,12 +109,17 @@ try:
 except OptionalDependencyNotAvailable:
     from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
 else:
-    from .loaders import TextualInversionLoaderMixin
     from .pipelines import (
         AltDiffusionImg2ImgPipeline,
         AltDiffusionPipeline,
         AudioLDMPipeline,
         CycleDiffusionPipeline,
+        IFImg2ImgPipeline,
+        IFImg2ImgSuperResolutionPipeline,
+        IFInpaintingPipeline,
+        IFInpaintingSuperResolutionPipeline,
+        IFPipeline,
+        IFSuperResolutionPipeline,
         LDMTextToImagePipeline,
         PaintByExamplePipeline,
         SemanticStableDiffusionPipeline,
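The six DeepFloyd IF pipelines added above become importable from the top-level package. A minimal import sketch (assumes `torch` and `transformers` are installed; without them the dummy placeholder objects are exported instead):

```py
# Minimal sketch: the new top-level exports in 0.16.x.
# Requires torch and transformers; otherwise diffusers exposes
# dummy placeholder objects instead of the real pipelines.
from diffusers import (
    IFImg2ImgPipeline,
    IFImg2ImgSuperResolutionPipeline,
    IFInpaintingPipeline,
    IFInpaintingSuperResolutionPipeline,
    IFPipeline,
    IFSuperResolutionPipeline,
)
```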
diffusers/configuration_utils.py CHANGED
@@ -109,6 +109,7 @@ class ConfigMixin:
         # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
         # or solve in a more general way.
         kwargs.pop("kwargs", None)
+
         if not hasattr(self, "_internal_dict"):
             internal_dict = kwargs
         else:
@@ -550,6 +551,9 @@ class ConfigMixin:
             return value
 
         config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
+        # Don't save "_ignore_files"
+        config_dict.pop("_ignore_files", None)
+
         return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
 
     def to_json_file(self, json_file_path: Union[str, os.PathLike]):
diffusers/loaders.py CHANGED
@@ -13,11 +13,17 @@
 # limitations under the License.
 import os
 from collections import defaultdict
+from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union
 
 import torch
+from huggingface_hub import hf_hub_download
 
-from .models.attention_processor import LoRAAttnProcessor
+from .models.attention_processor import (
+    CustomDiffusionAttnProcessor,
+    CustomDiffusionXFormersAttnProcessor,
+    LoRAAttnProcessor,
+)
 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
@@ -46,6 +52,9 @@ LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
 TEXT_INVERSION_NAME = "learned_embeds.bin"
 TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
 
+CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
+CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
+
 
 class AttnProcsLayers(torch.nn.Module):
     def __init__(self, state_dict: Dict[str, torch.Tensor]):
@@ -213,6 +222,7 @@ class UNet2DConditionLoadersMixin:
         attn_processors = {}
 
         is_lora = all("lora" in k for k in state_dict.keys())
+        is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
 
         if is_lora:
             lora_grouped_dict = defaultdict(dict)
@@ -229,9 +239,38 @@ class UNet2DConditionLoadersMixin:
                     hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
                 )
                 attn_processors[key].load_state_dict(value_dict)
-
+        elif is_custom_diffusion:
+            custom_diffusion_grouped_dict = defaultdict(dict)
+            for key, value in state_dict.items():
+                if len(value) == 0:
+                    custom_diffusion_grouped_dict[key] = {}
+                else:
+                    if "to_out" in key:
+                        attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+                    else:
+                        attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
+                    custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
+
+            for key, value_dict in custom_diffusion_grouped_dict.items():
+                if len(value_dict) == 0:
+                    attn_processors[key] = CustomDiffusionAttnProcessor(
+                        train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
+                    )
+                else:
+                    cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
+                    hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
+                    train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
+                    attn_processors[key] = CustomDiffusionAttnProcessor(
+                        train_kv=True,
+                        train_q_out=train_q_out,
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                    )
+                attn_processors[key].load_state_dict(value_dict)
         else:
-            raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
+            raise ValueError(
+                f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
+            )
 
         # set correct dtype & device
         attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
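For context, a hedged usage sketch of the new branch: a weight file whose keys contain `custom_diffusion` now routes through `load_attn_procs` into the `is_custom_diffusion` path above. The local directory below is hypothetical; the file name matches the new `CUSTOM_DIFFUSION_WEIGHT_NAME` constant.

```py
# Hypothetical sketch: loading Custom Diffusion attention processors onto a
# pipeline's UNet. The directory is illustrative, not a real checkpoint.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")
# A file named pytorch_custom_diffusion_weights.bin triggers the
# is_custom_diffusion branch shown above.
pipe.unet.load_attn_procs(
    "./my-custom-diffusion-dir", weight_name="pytorch_custom_diffusion_weights.bin"
)
```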
@@ -285,16 +324,31 @@ class UNet2DConditionLoadersMixin:
 
         os.makedirs(save_directory, exist_ok=True)
 
-        model_to_save = AttnProcsLayers(self.attn_processors)
-
-        # Save the model
-        state_dict = model_to_save.state_dict()
+        is_custom_diffusion = any(
+            isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
+            for (_, x) in self.attn_processors.items()
+        )
+        if is_custom_diffusion:
+            model_to_save = AttnProcsLayers(
+                {
+                    y: x
+                    for (y, x) in self.attn_processors.items()
+                    if isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
+                }
+            )
+            state_dict = model_to_save.state_dict()
+            for name, attn in self.attn_processors.items():
+                if len(attn.state_dict()) == 0:
+                    state_dict[name] = {}
+        else:
+            model_to_save = AttnProcsLayers(self.attn_processors)
+            state_dict = model_to_save.state_dict()
 
         if weight_name is None:
             if safe_serialization:
-                weight_name = LORA_WEIGHT_NAME_SAFE
+                weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
             else:
-                weight_name = LORA_WEIGHT_NAME
+                weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME
 
         # Save the model
         save_function(state_dict, os.path.join(save_directory, weight_name))
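Continuing the sketch above for the saving side: when any attached processor is a Custom Diffusion variant, the default file name switches to the new constants.

```py
# Sketch, continuing from the loading example: with Custom Diffusion
# processors attached, save_attn_procs now defaults to
# pytorch_custom_diffusion_weights.safetensors (with safe_serialization=True)
# or pytorch_custom_diffusion_weights.bin; plain LoRA keeps the LoRA names.
pipe.unet.save_attn_procs("./custom-diffusion-out", safe_serialization=True)
```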
@@ -356,7 +410,7 @@ class TextualInversionLoaderMixin:
             replacement = token
             i = 1
             while f"{token}_{i}" in tokenizer.added_tokens_encoder:
-                replacement += f"{token}_{i}"
+                replacement += f" {token}_{i}"
                 i += 1
 
             prompt = prompt.replace(token, replacement)
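The added space is what separates multi-vector textual inversion tokens in the expanded prompt; previously the extra tokens were concatenated into one unrecognized string. A standalone illustration of the fixed loop (token name and vector count hypothetical):

```py
# Standalone illustration of the fixed expansion (not the library API):
# a 3-vector embedding registered as charturnerv2, charturnerv2_1 and
# charturnerv2_2 must expand with spaces between the tokens.
token = "charturnerv2"
replacement = token
for i in range(1, 3):  # pretend the tokenizer holds two extra vectors
    replacement += f" {token}_{i}"

print(replacement)  # charturnerv2 charturnerv2_1 charturnerv2_2
# The old code produced "charturnerv2charturnerv2_1charturnerv2_2".
```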
@@ -431,6 +485,7 @@ class TextualInversionLoaderMixin:
         Example:
 
         To load a textual inversion embedding vector in `diffusers` format:
+
         ```py
         from diffusers import StableDiffusionPipeline
         import torch
@@ -456,13 +511,14 @@ class TextualInversionLoaderMixin:
         model_id = "runwayml/stable-diffusion-v1-5"
         pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
 
-        pipe.load_textual_inversion("./charturnerv2.pt")
+        pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")
 
         prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."
 
         image = pipe(prompt, num_inference_steps=50).images[0]
         image.save("character.png")
         ```
+
         """
         if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
             raise ValueError(
@@ -792,7 +848,7 @@ class LoraLoaderMixin:
         """
         # Loop over the original attention modules.
         for name, _ in self.text_encoder.named_modules():
-            if any([x in name for x in TEXT_ENCODER_TARGET_MODULES]):
+            if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
                 # Retrieve the module and its corresponding LoRA processor.
                 module = self.text_encoder.get_submodule(name)
                 # Construct a new function that performs the LoRA merging. We will monkey patch
@@ -1051,3 +1107,197 @@
 
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
+
+
+class FromCkptMixin:
+    """This helper class allows directly loading `.ckpt` Stable Diffusion checkpoints
+    into the respective pipeline classes."""
+
+    @classmethod
+    def from_ckpt(cls, pretrained_model_link_or_path, **kwargs):
+        r"""
+        Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights saved in the original .ckpt format.
+
+        The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
+
+        Parameters:
+            pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+                    - A link to the .ckpt file on the Hub. Should be in the format
+                      `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>"`
+                    - A path to a *file* containing all pipeline weights.
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed, the
+                dtype will be automatically derived from the model's weights.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                standard cache should not be used.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+                file exists.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            use_auth_token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token
+                generated when running `huggingface-cli login` (stored in `~/.huggingface`).
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                identifier allowed by git.
+            use_safetensors (`bool`, *optional*):
+                If set to `True`, the pipeline will be loaded from `safetensors` weights. If set to `None` (the
+                default), the pipeline will load using `safetensors` if the safetensors weights are available *and* if
+                `safetensors` is installed. If set to `False`, the pipeline will *not* use `safetensors`.
+            extract_ema (`bool`, *optional*, defaults to `False`):
+                Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA
+                weights or not. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality
+                images for inference. Non-EMA weights are usually better to continue fine-tuning.
+            upcast_attention (`bool`, *optional*, defaults to `None`):
+                Whether the attention computation should always be upcasted. This is necessary when running stable
+                diffusion 2.1.
+            image_size (`int`, *optional*, defaults to 512):
+                The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion
+                v2 Base. Use 768 for Stable Diffusion v2.
+            prediction_type (`str`, *optional*):
+                The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and
+                Stable Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2.
+            num_in_channels (`int`, *optional*, defaults to `None`):
+                The number of input channels. If `None`, it will be automatically inferred.
+            scheduler_type (`str`, *optional*, defaults to `'pndm'`):
+                Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
+                "ddim"]`.
+            load_safety_checker (`bool`, *optional*, defaults to `True`):
+                Whether to load the safety checker or not.
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to overwrite load- and saveable variables - *i.e.* the pipeline components - of the
+                specific pipeline class. The overwritten components are then directly passed to the pipeline's
+                `__init__` method. See example below for more information.
+
+        Examples:
+
+        ```py
+        >>> import torch
+        >>> from diffusers import StableDiffusionPipeline
+
+        >>> # Download pipeline from huggingface.co and cache.
+        >>> pipeline = StableDiffusionPipeline.from_ckpt(
+        ...     "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
+        ... )
+
+        >>> # Load pipeline from a local file
+        >>> # (the file was previously downloaded to ./v1-5-pruned-emaonly.ckpt)
+        >>> pipeline = StableDiffusionPipeline.from_ckpt("./v1-5-pruned-emaonly.ckpt")
+
+        >>> # Enable float16 and move to GPU
+        >>> pipeline = StableDiffusionPipeline.from_ckpt(
+        ...     "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
+        ...     torch_dtype=torch.float16,
+        ... )
+        >>> pipeline.to("cuda")
+        ```
+        """
+        # import here to avoid circular dependency
+        from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
+
+        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        resume_download = kwargs.pop("resume_download", False)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
+        extract_ema = kwargs.pop("extract_ema", False)
+        image_size = kwargs.pop("image_size", 512)
+        scheduler_type = kwargs.pop("scheduler_type", "pndm")
+        num_in_channels = kwargs.pop("num_in_channels", None)
+        upcast_attention = kwargs.pop("upcast_attention", None)
+        load_safety_checker = kwargs.pop("load_safety_checker", True)
+        prediction_type = kwargs.pop("prediction_type", None)
+
+        torch_dtype = kwargs.pop("torch_dtype", None)
+
+        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+
+        pipeline_name = cls.__name__
+        file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
+        from_safetensors = file_extension == "safetensors"
+
+        if from_safetensors and use_safetensors is False:
+            raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
+
+        # TODO: For now we only support stable diffusion
+        stable_unclip = None
+        controlnet = False
+
+        if pipeline_name == "StableDiffusionControlNetPipeline":
+            model_type = "FrozenCLIPEmbedder"
+            controlnet = True
+        elif "StableDiffusion" in pipeline_name:
+            model_type = "FrozenCLIPEmbedder"
+        elif pipeline_name == "StableUnCLIPPipeline":
+            model_type = "FrozenOpenCLIPEmbedder"
+            stable_unclip = "txt2img"
+        elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
+            model_type = "FrozenOpenCLIPEmbedder"
+            stable_unclip = "img2img"
+        elif pipeline_name == "PaintByExamplePipeline":
+            model_type = "PaintByExample"
+        elif pipeline_name == "LDMTextToImagePipeline":
+            model_type = "LDMTextToImage"
+        else:
+            raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
+
+        # remove huggingface url
+        for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
+            if pretrained_model_link_or_path.startswith(prefix):
+                pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
+
+        # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
+        ckpt_path = Path(pretrained_model_link_or_path)
+        if not ckpt_path.is_file():
+            # get repo_id and (potentially nested) file path of ckpt in repo
+            repo_id = str(Path().joinpath(*ckpt_path.parts[:2]))
+            file_path = str(Path().joinpath(*ckpt_path.parts[2:]))
+
+            if file_path.startswith("blob/"):
+                file_path = file_path[len("blob/") :]
+
+            if file_path.startswith("main/"):
+                file_path = file_path[len("main/") :]
+
+            pretrained_model_link_or_path = hf_hub_download(
+                repo_id,
+                filename=file_path,
+                cache_dir=cache_dir,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                force_download=force_download,
+            )
+
+        pipe = download_from_original_stable_diffusion_ckpt(
+            pretrained_model_link_or_path,
+            pipeline_class=cls,
+            model_type=model_type,
+            stable_unclip=stable_unclip,
+            controlnet=controlnet,
+            from_safetensors=from_safetensors,
+            extract_ema=extract_ema,
+            image_size=image_size,
+            scheduler_type=scheduler_type,
+            num_in_channels=num_in_channels,
+            upcast_attention=upcast_attention,
+            load_safety_checker=load_safety_checker,
+            prediction_type=prediction_type,
+        )
+
+        if torch_dtype is not None:
+            pipe.to(torch_dtype=torch_dtype)
+
+        return pipe
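A standalone sketch of the link normalization `from_ckpt` performs before calling `hf_hub_download` (the URL is illustrative; all names are local to the sketch):

```py
# Standalone sketch of the URL handling above: strip the host prefix, split
# the repo id from the in-repo path, and drop the blob/ and main/ segments
# so only the checkpoint's file path remains.
from pathlib import Path

link = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt"
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
    if link.startswith(prefix):
        link = link[len(prefix):]

parts = Path(link).parts
repo_id = "/".join(parts[:2])
file_path = "/".join(parts[2:])
for segment in ("blob/", "main/"):
    if file_path.startswith(segment):
        file_path = file_path[len(segment):]

print(repo_id)    # runwayml/stable-diffusion-v1-5
print(file_path)  # v1-5-pruned-emaonly.ckpt
```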
diffusers/models/attention.py CHANGED
@@ -60,7 +60,6 @@ class AttentionBlock(nn.Module):
         self.channels = channels
 
         self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
-        self.num_head_size = num_head_channels
         self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
 
         # define q,k,v as linear layers
@@ -72,20 +71,30 @@ class AttentionBlock(nn.Module):
         self.proj_attn = nn.Linear(channels, channels, bias=True)
 
         self._use_memory_efficient_attention_xformers = False
+        self._use_2_0_attn = True
         self._attention_op = None
 
-    def reshape_heads_to_batch_dim(self, tensor):
+    def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.num_heads
         tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3)
+        if merge_head_and_batch:
+            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
         return tensor
 
-    def reshape_batch_dim_to_heads(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
+    def reshape_batch_dim_to_heads(self, tensor, unmerge_head_and_batch=True):
         head_size = self.num_heads
-        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+
+        if unmerge_head_and_batch:
+            batch_head_size, seq_len, dim = tensor.shape
+            batch_size = batch_head_size // head_size
+
+            tensor = tensor.reshape(batch_size, head_size, seq_len, dim)
+        else:
+            batch_size, _, seq_len, dim = tensor.shape
+
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim * head_size)
         return tensor
 
     def set_use_memory_efficient_attention_xformers(
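The `merge_head_and_batch` / `unmerge_head_and_batch` flags exist because the two attention backends expect different layouts. A small shape sketch (sizes hypothetical):

```py
# Shape sketch (hypothetical sizes): the merged 3-D layout feeds the xformers
# and manual paths; the unmerged 4-D layout feeds F.scaled_dot_product_attention.
import torch

batch, heads, seq_len, head_dim = 2, 8, 64, 40
x = torch.randn(batch, seq_len, heads * head_dim)

split = x.reshape(batch, seq_len, heads, head_dim).permute(0, 2, 1, 3)
merged = split.reshape(batch * heads, seq_len, head_dim)

print(split.shape)   # torch.Size([2, 8, 64, 40])
print(merged.shape)  # torch.Size([16, 64, 40])
```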
@@ -134,14 +143,24 @@ class AttentionBlock(nn.Module):
 
         scale = 1 / math.sqrt(self.channels / self.num_heads)
 
-        query_proj = self.reshape_heads_to_batch_dim(query_proj)
-        key_proj = self.reshape_heads_to_batch_dim(key_proj)
-        value_proj = self.reshape_heads_to_batch_dim(value_proj)
+        _use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers
+        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn
+
+        query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
+        key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)
+        value_proj = self.reshape_heads_to_batch_dim(value_proj, merge_head_and_batch=not use_torch_2_0_attn)
 
         if self._use_memory_efficient_attention_xformers:
             # Memory efficient attention
             hidden_states = xformers.ops.memory_efficient_attention(
-                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
+                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op, scale=scale
+            )
+            hidden_states = hidden_states.to(query_proj.dtype)
+        elif use_torch_2_0_attn:
+            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            hidden_states = F.scaled_dot_product_attention(
+                query_proj, key_proj, value_proj, dropout_p=0.0, is_causal=False
             )
             hidden_states = hidden_states.to(query_proj.dtype)
         else:
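A hedged equivalence check for the new PyTorch 2.0 path (requires PyTorch >= 2.0): with its default scaling of `1 / sqrt(head_dim)`, which here equals `1 / math.sqrt(self.channels / self.num_heads)`, the fused kernel matches a manual softmax attention:

```py
# Sanity sketch (requires PyTorch >= 2.0): the fused kernel with its default
# 1/sqrt(head_dim) scaling matches manual softmax(Q @ K^T * scale) @ V.
import math
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 8, 16, 40
q, k, v = (torch.randn(batch, heads, seq_len, head_dim) for _ in range(3))

fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

scale = 1 / math.sqrt(head_dim)
weights = torch.softmax(q @ k.transpose(-1, -2) * scale, dim=-1)
manual = weights @ v

print(torch.allclose(fused, manual, atol=1e-5))  # True
```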
@@ -162,7 +181,7 @@ class AttentionBlock(nn.Module):
             hidden_states = torch.bmm(attention_probs, value_proj)
 
         # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states, unmerge_head_and_batch=not use_torch_2_0_attn)
 
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)