diffusers-0.15.1-py3-none-any.whl → diffusers-0.16.1-py3-none-any.whl

Files changed (57)
  1. diffusers/__init__.py +7 -2
  2. diffusers/configuration_utils.py +4 -0
  3. diffusers/loaders.py +262 -12
  4. diffusers/models/attention.py +31 -12
  5. diffusers/models/attention_processor.py +189 -0
  6. diffusers/models/controlnet.py +9 -2
  7. diffusers/models/embeddings.py +66 -0
  8. diffusers/models/modeling_pytorch_flax_utils.py +6 -0
  9. diffusers/models/modeling_utils.py +5 -2
  10. diffusers/models/transformer_2d.py +1 -1
  11. diffusers/models/unet_2d_condition.py +45 -6
  12. diffusers/models/vae.py +3 -0
  13. diffusers/pipelines/__init__.py +8 -0
  14. diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +25 -10
  15. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +8 -0
  16. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +8 -0
  17. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
  18. diffusers/pipelines/deepfloyd_if/__init__.py +54 -0
  19. diffusers/pipelines/deepfloyd_if/pipeline_if.py +854 -0
  20. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +979 -0
  21. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1097 -0
  22. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1098 -0
  23. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1208 -0
  24. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +947 -0
  25. diffusers/pipelines/deepfloyd_if/safety_checker.py +59 -0
  26. diffusers/pipelines/deepfloyd_if/timesteps.py +579 -0
  27. diffusers/pipelines/deepfloyd_if/watermark.py +46 -0
  28. diffusers/pipelines/pipeline_utils.py +54 -25
  29. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +37 -20
  30. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +1 -1
  31. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +12 -1
  32. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -2
  33. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -8
  34. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +59 -4
  35. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +9 -2
  36. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -2
  37. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +9 -2
  38. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +22 -12
  39. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +9 -2
  40. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +34 -30
  41. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +93 -10
  42. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +45 -6
  43. diffusers/schedulers/scheduling_ddpm.py +63 -16
  44. diffusers/schedulers/scheduling_heun_discrete.py +51 -1
  45. diffusers/utils/__init__.py +4 -1
  46. diffusers/utils/dummy_torch_and_transformers_objects.py +80 -5
  47. diffusers/utils/dynamic_modules_utils.py +1 -1
  48. diffusers/utils/hub_utils.py +4 -1
  49. diffusers/utils/import_utils.py +41 -0
  50. diffusers/utils/pil_utils.py +24 -0
  51. diffusers/utils/testing_utils.py +10 -0
  52. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/METADATA +1 -1
  53. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/RECORD +57 -47
  54. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/LICENSE +0 -0
  55. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/WHEEL +0 -0
  56. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/entry_points.txt +0 -0
  57. {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/top_level.txt +0 -0
diffusers/__init__.py CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.15.1"
+__version__ = "0.16.1"
 
 from .configuration_utils import ConfigMixin
 from .utils import (
@@ -109,12 +109,17 @@ try:
 except OptionalDependencyNotAvailable:
     from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
 else:
-    from .loaders import TextualInversionLoaderMixin
     from .pipelines import (
         AltDiffusionImg2ImgPipeline,
         AltDiffusionPipeline,
         AudioLDMPipeline,
         CycleDiffusionPipeline,
+        IFImg2ImgPipeline,
+        IFImg2ImgSuperResolutionPipeline,
+        IFInpaintingPipeline,
+        IFInpaintingSuperResolutionPipeline,
+        IFPipeline,
+        IFSuperResolutionPipeline,
         LDMTextToImagePipeline,
         PaintByExamplePipeline,
         SemanticStableDiffusionPipeline,
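
The six new `IF*` exports are the DeepFloyd IF pipelines added in this release (see the `deepfloyd_if` entries in the file list above). A minimal two-stage usage sketch, assuming the publicly released DeepFloyd model IDs, which are not part of this diff:

```py
import torch
from diffusers import IFPipeline, IFSuperResolutionPipeline

# Stage I: text-to-image at 64x64 (model ID assumed from the DeepFloyd release).
pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

prompt_embeds, negative_embeds = pipe.encode_prompt("a photo of a corgi wearing a party hat")
image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images

# Stage II: super-resolve 64x64 -> 256x256; the text encoder can be dropped
# because the precomputed prompt embeddings are reused.
super_res = IFSuperResolutionPipeline.from_pretrained(
    "DeepFloyd/IF-II-L-v1.0", text_encoder=None, torch_dtype=torch.float16
)
super_res.enable_model_cpu_offload()
image = super_res(
    image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds
).images[0]
image.save("if_corgi.png")
```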
diffusers/configuration_utils.py CHANGED
@@ -109,6 +109,7 @@ class ConfigMixin:
         # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
         # or solve in a more general way.
         kwargs.pop("kwargs", None)
+
         if not hasattr(self, "_internal_dict"):
             internal_dict = kwargs
         else:
@@ -550,6 +551,9 @@ class ConfigMixin:
             return value
 
         config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
+        # Don't save "_ignore_files"
+        config_dict.pop("_ignore_files", None)
+
         return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
 
     def to_json_file(self, json_file_path: Union[str, os.PathLike]):
diffusers/loaders.py CHANGED
@@ -13,11 +13,17 @@
 # limitations under the License.
 import os
 from collections import defaultdict
+from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union
 
 import torch
+from huggingface_hub import hf_hub_download
 
-from .models.attention_processor import LoRAAttnProcessor
+from .models.attention_processor import (
+    CustomDiffusionAttnProcessor,
+    CustomDiffusionXFormersAttnProcessor,
+    LoRAAttnProcessor,
+)
 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
@@ -46,6 +52,9 @@ LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
 TEXT_INVERSION_NAME = "learned_embeds.bin"
 TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
 
+CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
+CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
+
 
 class AttnProcsLayers(torch.nn.Module):
     def __init__(self, state_dict: Dict[str, torch.Tensor]):
@@ -213,6 +222,7 @@ class UNet2DConditionLoadersMixin:
         attn_processors = {}
 
         is_lora = all("lora" in k for k in state_dict.keys())
+        is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
 
         if is_lora:
             lora_grouped_dict = defaultdict(dict)
@@ -229,9 +239,38 @@
                     hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
                 )
                 attn_processors[key].load_state_dict(value_dict)
-
+        elif is_custom_diffusion:
+            custom_diffusion_grouped_dict = defaultdict(dict)
+            for key, value in state_dict.items():
+                if len(value) == 0:
+                    custom_diffusion_grouped_dict[key] = {}
+                else:
+                    if "to_out" in key:
+                        attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+                    else:
+                        attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
+                    custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
+
+            for key, value_dict in custom_diffusion_grouped_dict.items():
+                if len(value_dict) == 0:
+                    attn_processors[key] = CustomDiffusionAttnProcessor(
+                        train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
+                    )
+                else:
+                    cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
+                    hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
+                    train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
+                    attn_processors[key] = CustomDiffusionAttnProcessor(
+                        train_kv=True,
+                        train_q_out=train_q_out,
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                    )
+                    attn_processors[key].load_state_dict(value_dict)
         else:
-            raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
+            raise ValueError(
+                f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
+            )
 
         # set correct dtype & device
         attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
@@ -285,16 +324,31 @@ class UNet2DConditionLoadersMixin:
 
         os.makedirs(save_directory, exist_ok=True)
 
-        model_to_save = AttnProcsLayers(self.attn_processors)
-
-        # Save the model
-        state_dict = model_to_save.state_dict()
+        is_custom_diffusion = any(
+            isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
+            for (_, x) in self.attn_processors.items()
+        )
+        if is_custom_diffusion:
+            model_to_save = AttnProcsLayers(
+                {
+                    y: x
+                    for (y, x) in self.attn_processors.items()
+                    if isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
+                }
+            )
+            state_dict = model_to_save.state_dict()
+            for name, attn in self.attn_processors.items():
+                if len(attn.state_dict()) == 0:
+                    state_dict[name] = {}
+        else:
+            model_to_save = AttnProcsLayers(self.attn_processors)
+            state_dict = model_to_save.state_dict()
 
         if weight_name is None:
             if safe_serialization:
-                weight_name = LORA_WEIGHT_NAME_SAFE
+                weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
             else:
-                weight_name = LORA_WEIGHT_NAME
+                weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME
 
         # Save the model
         save_function(state_dict, os.path.join(save_directory, weight_name))
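
Together, the two hunks above make Custom Diffusion weights round-trip through the same attention-processor API as LoRA: `save_attn_procs` picks the new file names when Custom Diffusion processors are attached, and `load_attn_procs` detects the format from the state-dict keys. A minimal sketch of the round trip, assuming a Custom Diffusion training run has already attached the processors (the directory name is hypothetical):

```py
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)

# After training has set CustomDiffusionAttnProcessor modules on pipe.unet,
# saving automatically selects the custom-diffusion file name:
pipe.unet.save_attn_procs("./custom-diffusion-weights", safe_serialization=False)
# -> ./custom-diffusion-weights/pytorch_custom_diffusion_weights.bin

# Loading rebuilds CustomDiffusionAttnProcessor instances, inferring
# hidden_size/cross_attention_dim from the stored to_k weights:
pipe.unet.load_attn_procs(
    "./custom-diffusion-weights", weight_name="pytorch_custom_diffusion_weights.bin"
)
```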
@@ -356,7 +410,7 @@ class TextualInversionLoaderMixin:
                 replacement = token
                 i = 1
                 while f"{token}_{i}" in tokenizer.added_tokens_encoder:
-                    replacement += f"{token}_{i}"
+                    replacement += f" {token}_{i}"
                     i += 1
 
                 prompt = prompt.replace(token, replacement)
@@ -431,6 +485,7 @@
         Example:
 
         To load a textual inversion embedding vector in `diffusers` format:
+
         ```py
         from diffusers import StableDiffusionPipeline
         import torch
@@ -456,13 +511,14 @@
         model_id = "runwayml/stable-diffusion-v1-5"
         pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
 
-        pipe.load_textual_inversion("./charturnerv2.pt")
+        pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")
 
         prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."
 
         image = pipe(prompt, num_inference_steps=50).images[0]
         image.save("character.png")
         ```
+
         """
         if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
             raise ValueError(
@@ -792,7 +848,7 @@
         """
         # Loop over the original attention modules.
         for name, _ in self.text_encoder.named_modules():
-            if any([x in name for x in TEXT_ENCODER_TARGET_MODULES]):
+            if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
                 # Retrieve the module and its corresponding LoRA processor.
                 module = self.text_encoder.get_submodule(name)
                 # Construct a new function that performs the LoRA merging. We will monkey patch
@@ -1051,3 +1107,197 @@
 
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
+
+
+class FromCkptMixin:
+    """This helper class allows to directly load .ckpt stable diffusion checkpoints
+    into the respective pipeline classes."""
+
+    @classmethod
+    def from_ckpt(cls, pretrained_model_link_or_path, **kwargs):
+        r"""
+        Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights saved in the original .ckpt format.
+
+        The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
+
+        Parameters:
+            pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+                    - A link to the .ckpt file on the Hub. Should be in the format
+                      `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>"`
+                    - A path to a *file* containing all pipeline weights.
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the
+                dtype will be automatically derived from the model's weights.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding
+                the cached versions if they exist.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                standard cache should not be used.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+                file exists.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            use_auth_token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token
+                generated when running `huggingface-cli login` (stored in `~/.huggingface`).
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use
+                a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                identifier allowed by git.
+            use_safetensors (`bool`, *optional*):
+                If set to `True`, the pipeline will be loaded from `safetensors` weights. If set to `None` (the
+                default), the pipeline will load using `safetensors` if the safetensors weights are available *and*
+                if `safetensors` is installed. If set to `False` the pipeline will *not* use `safetensors`.
+            extract_ema (`bool`, *optional*, defaults to `False`):
+                Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA
+                weights or not. Pass `True` to extract the EMA weights, which usually yield higher quality images for
+                inference. Non-EMA weights are usually better to continue fine-tuning.
+            upcast_attention (`bool`, *optional*, defaults to `None`):
+                Whether the attention computation should always be upcasted. This is necessary when running stable
+                diffusion 2.1.
+            image_size (`int`, *optional*, defaults to 512):
+                The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion
+                v2 Base. Use 768 for Stable Diffusion v2.
+            prediction_type (`str`, *optional*):
+                The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and
+                Stable Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2.
+            num_in_channels (`int`, *optional*, defaults to `None`):
+                The number of input channels. If `None`, it will be automatically inferred.
+            scheduler_type (`str`, *optional*, defaults to `'pndm'`):
+                Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral",
+                "dpm", "ddim"]`.
+            load_safety_checker (`bool`, *optional*, defaults to `True`):
+                Whether to load the safety checker or not.
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to overwrite load- and saveable variables, *i.e.* the pipeline components of the specific
+                pipeline class. The overwritten components are then directly passed to the pipeline's `__init__`
+                method. See example below for more information.
+
+        Examples:
+
+        ```py
+        >>> from diffusers import StableDiffusionPipeline
+
+        >>> # Download pipeline from huggingface.co and cache.
+        >>> pipeline = StableDiffusionPipeline.from_ckpt(
+        ...     "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
+        ... )

+        >>> # Load pipeline from a local file
+        >>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
+        >>> pipeline = StableDiffusionPipeline.from_ckpt("./v1-5-pruned-emaonly.ckpt")
+
+        >>> # Enable float16 and move to GPU
+        >>> import torch
+        >>> pipeline = StableDiffusionPipeline.from_ckpt(
+        ...     "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
+        ...     torch_dtype=torch.float16,
+        ... )
+        >>> pipeline.to("cuda")
+        ```
+        """
+        # import here to avoid circular dependency
+        from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
+
+        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        resume_download = kwargs.pop("resume_download", False)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
+        extract_ema = kwargs.pop("extract_ema", False)
+        image_size = kwargs.pop("image_size", 512)
+        scheduler_type = kwargs.pop("scheduler_type", "pndm")
+        num_in_channels = kwargs.pop("num_in_channels", None)
+        upcast_attention = kwargs.pop("upcast_attention", None)
+        load_safety_checker = kwargs.pop("load_safety_checker", True)
+        prediction_type = kwargs.pop("prediction_type", None)
+
+        torch_dtype = kwargs.pop("torch_dtype", None)
+
+        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+
+        pipeline_name = cls.__name__
+        file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
+        from_safetensors = file_extension == "safetensors"
+
+        if from_safetensors and use_safetensors is True:
+            raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
+
+        # TODO: For now we only support stable diffusion
+        stable_unclip = None
+        controlnet = False
+
+        if pipeline_name == "StableDiffusionControlNetPipeline":
+            model_type = "FrozenCLIPEmbedder"
+            controlnet = True
+        elif "StableDiffusion" in pipeline_name:
+            model_type = "FrozenCLIPEmbedder"
+        elif pipeline_name == "StableUnCLIPPipeline":
+            model_type == "FrozenOpenCLIPEmbedder"
+            stable_unclip = "txt2img"
+        elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
+            model_type == "FrozenOpenCLIPEmbedder"
+            stable_unclip = "img2img"
+        elif pipeline_name == "PaintByExamplePipeline":
+            model_type == "PaintByExample"
+        elif pipeline_name == "LDMTextToImagePipeline":
+            model_type == "LDMTextToImage"
+        else:
+            raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
+
+        # remove huggingface url
+        for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
+            if pretrained_model_link_or_path.startswith(prefix):
+                pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
+
+        # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
+        ckpt_path = Path(pretrained_model_link_or_path)
+        if not ckpt_path.is_file():
+            # get repo_id and (potentially nested) file path of ckpt in repo
+            repo_id = str(Path().joinpath(*ckpt_path.parts[:2]))
+            file_path = str(Path().joinpath(*ckpt_path.parts[2:]))
+
+            if file_path.startswith("blob/"):
+                file_path = file_path[len("blob/") :]
+
+            if file_path.startswith("main/"):
+                file_path = file_path[len("main/") :]
+
+            pretrained_model_link_or_path = hf_hub_download(
+                repo_id,
+                filename=file_path,
+                cache_dir=cache_dir,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                force_download=force_download,
+            )
+
+        pipe = download_from_original_stable_diffusion_ckpt(
+            pretrained_model_link_or_path,
+            pipeline_class=cls,
+            model_type=model_type,
+            stable_unclip=stable_unclip,
+            controlnet=controlnet,
+            from_safetensors=from_safetensors,
+            extract_ema=extract_ema,
+            image_size=image_size,
+            scheduler_type=scheduler_type,
+            num_in_channels=num_in_channels,
+            upcast_attention=upcast_attention,
+            load_safety_checker=load_safety_checker,
+            prediction_type=prediction_type,
+        )
+
+        if torch_dtype is not None:
+            pipe.to(torch_dtype=torch_dtype)
+
+        return pipe
diffusers/models/attention.py CHANGED
@@ -60,7 +60,6 @@ class AttentionBlock(nn.Module):
         self.channels = channels
 
         self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
-        self.num_head_size = num_head_channels
        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
 
         # define q,k,v as linear layers
@@ -72,20 +71,30 @@
         self.proj_attn = nn.Linear(channels, channels, bias=True)
 
         self._use_memory_efficient_attention_xformers = False
+        self._use_2_0_attn = True
         self._attention_op = None
 
-    def reshape_heads_to_batch_dim(self, tensor):
+    def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.num_heads
         tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3)
+        if merge_head_and_batch:
+            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
         return tensor
 
-    def reshape_batch_dim_to_heads(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
+    def reshape_batch_dim_to_heads(self, tensor, unmerge_head_and_batch=True):
         head_size = self.num_heads
-        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+
+        if unmerge_head_and_batch:
+            batch_head_size, seq_len, dim = tensor.shape
+            batch_size = batch_head_size // head_size
+
+            tensor = tensor.reshape(batch_size, head_size, seq_len, dim)
+        else:
+            batch_size, _, seq_len, dim = tensor.shape
+
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim * head_size)
         return tensor
 
     def set_use_memory_efficient_attention_xformers(
@@ -134,14 +143,24 @@
 
         scale = 1 / math.sqrt(self.channels / self.num_heads)
 
-        query_proj = self.reshape_heads_to_batch_dim(query_proj)
-        key_proj = self.reshape_heads_to_batch_dim(key_proj)
-        value_proj = self.reshape_heads_to_batch_dim(value_proj)
+        _use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers
+        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn
+
+        query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
+        key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)
+        value_proj = self.reshape_heads_to_batch_dim(value_proj, merge_head_and_batch=not use_torch_2_0_attn)
 
         if self._use_memory_efficient_attention_xformers:
             # Memory efficient attention
             hidden_states = xformers.ops.memory_efficient_attention(
-                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
+                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op, scale=scale
+            )
+            hidden_states = hidden_states.to(query_proj.dtype)
+        elif use_torch_2_0_attn:
+            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            hidden_states = F.scaled_dot_product_attention(
+                query_proj, key_proj, value_proj, dropout_p=0.0, is_causal=False
             )
             hidden_states = hidden_states.to(query_proj.dtype)
         else:
@@ -162,7 +181,7 @@
         hidden_states = torch.bmm(attention_probs, value_proj)
 
         # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states, unmerge_head_and_batch=not use_torch_2_0_attn)
 
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
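
Net effect of the `AttentionBlock` changes: when PyTorch 2.0's fused `scaled_dot_product_attention` is available (and xFormers is not enabled), the reshape helpers keep the head dimension separate, because the fused kernel consumes `(batch, heads, seq_len, head_dim)` directly. Note that `AttentionBlock`'s own `scale`, `1 / sqrt(channels / num_heads)`, equals SDPA's default `1 / sqrt(head_dim)`, so not passing an explicit scale (per the TODO above) is benign here. A standalone sketch of the equivalence, not taken from this diff:

```py
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 4, 16, 32
q, k, v = (torch.randn(batch, heads, seq_len, head_dim) for _ in range(3))

# Fused path: operates on (batch, heads, seq_len, head_dim) with the
# default scale 1/sqrt(head_dim).
out_sdpa = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

# Manual reference path, merging heads into the batch dimension as the
# pre-2.0 code did.
scale = head_dim ** -0.5
q2 = q.reshape(batch * heads, seq_len, head_dim)
k2 = k.reshape(batch * heads, seq_len, head_dim)
v2 = v.reshape(batch * heads, seq_len, head_dim)
attn = (torch.bmm(q2, k2.transpose(1, 2)) * scale).softmax(dim=-1)
out_manual = torch.bmm(attn, v2).reshape(batch, heads, seq_len, head_dim)

print(torch.allclose(out_sdpa, out_manual, atol=1e-5))  # expected: True
```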