diffusers-0.15.1-py3-none-any.whl → diffusers-0.16.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +7 -2
- diffusers/configuration_utils.py +4 -0
- diffusers/loaders.py +262 -12
- diffusers/models/attention.py +31 -12
- diffusers/models/attention_processor.py +189 -0
- diffusers/models/controlnet.py +9 -2
- diffusers/models/embeddings.py +66 -0
- diffusers/models/modeling_pytorch_flax_utils.py +6 -0
- diffusers/models/modeling_utils.py +5 -2
- diffusers/models/transformer_2d.py +1 -1
- diffusers/models/unet_2d_condition.py +45 -6
- diffusers/models/vae.py +3 -0
- diffusers/pipelines/__init__.py +8 -0
- diffusers/pipelines/alt_diffusion/modeling_roberta_series.py +25 -10
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +8 -0
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +8 -0
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
- diffusers/pipelines/deepfloyd_if/__init__.py +54 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +854 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +979 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1097 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1098 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1208 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +947 -0
- diffusers/pipelines/deepfloyd_if/safety_checker.py +59 -0
- diffusers/pipelines/deepfloyd_if/timesteps.py +579 -0
- diffusers/pipelines/deepfloyd_if/watermark.py +46 -0
- diffusers/pipelines/pipeline_utils.py +54 -25
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +37 -20
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +59 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +22 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +34 -30
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +93 -10
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +45 -6
- diffusers/schedulers/scheduling_ddpm.py +63 -16
- diffusers/schedulers/scheduling_heun_discrete.py +51 -1
- diffusers/utils/__init__.py +4 -1
- diffusers/utils/dummy_torch_and_transformers_objects.py +80 -5
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/hub_utils.py +4 -1
- diffusers/utils/import_utils.py +41 -0
- diffusers/utils/pil_utils.py +24 -0
- diffusers/utils/testing_utils.py +10 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/METADATA +1 -1
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/RECORD +57 -47
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/LICENSE +0 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/WHEEL +0 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.15.1.dist-info → diffusers-0.16.1.dist-info}/top_level.txt +0 -0
diffusers/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.15.1"
+__version__ = "0.16.1"
 
 from .configuration_utils import ConfigMixin
 from .utils import (
@@ -109,12 +109,17 @@ try:
 except OptionalDependencyNotAvailable:
     from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
 else:
-    from .loaders import TextualInversionLoaderMixin
     from .pipelines import (
         AltDiffusionImg2ImgPipeline,
         AltDiffusionPipeline,
         AudioLDMPipeline,
         CycleDiffusionPipeline,
+        IFImg2ImgPipeline,
+        IFImg2ImgSuperResolutionPipeline,
+        IFInpaintingPipeline,
+        IFInpaintingSuperResolutionPipeline,
+        IFPipeline,
+        IFSuperResolutionPipeline,
         LDMTextToImagePipeline,
         PaintByExamplePipeline,
         SemanticStableDiffusionPipeline,
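The headline addition in 0.16 is the DeepFloyd IF pipeline family exported above. A minimal usage sketch of the first stage, assuming the gated `DeepFloyd/IF-I-XL-v1.0` weights and the fp16 variant as published on the model card (details may differ from your setup):

```py
# Hedged sketch: repo id, fp16 variant, and encode_prompt usage follow the
# DeepFloyd IF model card, not this diff.
import torch
from diffusers import IFPipeline

pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()  # the T5 text encoder is large; offload to fit on one GPU

# IF encodes prompts with T5 up front, so the embeddings can be reused across stages
prompt_embeds, negative_embeds = pipe.encode_prompt("a photo of a kangaroo wearing an orange hoodie")
image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, output_type="pt").images
```

The low-resolution output is then meant to be fed to `IFSuperResolutionPipeline` (and its img2img/inpainting variants listed in the file table) for upscaling.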
diffusers/configuration_utils.py
CHANGED
@@ -109,6 +109,7 @@ class ConfigMixin:
         # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
         # or solve in a more general way.
         kwargs.pop("kwargs", None)
+
         if not hasattr(self, "_internal_dict"):
             internal_dict = kwargs
         else:
@@ -550,6 +551,9 @@ class ConfigMixin:
             return value
 
         config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
+        # Don't save "_ignore_files"
+        config_dict.pop("_ignore_files", None)
+
         return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
 
     def to_json_file(self, json_file_path: Union[str, os.PathLike]):
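The `to_json_string` change keeps the private `_ignore_files` bookkeeping entry out of saved configs. A toy sketch of the effect (the dict contents are hypothetical):

```py
import json

# "_ignore_files" is internal state and should never land in config.json
config_dict = {"sample_size": 64, "_ignore_files": ["vae/diffusion_pytorch_model.bin"]}
config_dict.pop("_ignore_files", None)  # mirrors the new to_json_string behavior
print(json.dumps(config_dict, indent=2, sort_keys=True))  # only "sample_size" remains
```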
diffusers/loaders.py
CHANGED
@@ -13,11 +13,17 @@
 # limitations under the License.
 import os
 from collections import defaultdict
+from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union
 
 import torch
+from huggingface_hub import hf_hub_download
 
-from .models.attention_processor import LoRAAttnProcessor
+from .models.attention_processor import (
+    CustomDiffusionAttnProcessor,
+    CustomDiffusionXFormersAttnProcessor,
+    LoRAAttnProcessor,
+)
 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
@@ -46,6 +52,9 @@ LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
 TEXT_INVERSION_NAME = "learned_embeds.bin"
 TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
 
+CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
+CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
+
 
 class AttnProcsLayers(torch.nn.Module):
     def __init__(self, state_dict: Dict[str, torch.Tensor]):
@@ -213,6 +222,7 @@ class UNet2DConditionLoadersMixin:
         attn_processors = {}
 
         is_lora = all("lora" in k for k in state_dict.keys())
+        is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
 
         if is_lora:
             lora_grouped_dict = defaultdict(dict)
@@ -229,9 +239,38 @@
                     hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
                 )
                 attn_processors[key].load_state_dict(value_dict)
-
+        elif is_custom_diffusion:
+            custom_diffusion_grouped_dict = defaultdict(dict)
+            for key, value in state_dict.items():
+                if len(value) == 0:
+                    custom_diffusion_grouped_dict[key] = {}
+                else:
+                    if "to_out" in key:
+                        attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+                    else:
+                        attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
+                    custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
+
+            for key, value_dict in custom_diffusion_grouped_dict.items():
+                if len(value_dict) == 0:
+                    attn_processors[key] = CustomDiffusionAttnProcessor(
+                        train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
+                    )
+                else:
+                    cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
+                    hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
+                    train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
+                    attn_processors[key] = CustomDiffusionAttnProcessor(
+                        train_kv=True,
+                        train_q_out=train_q_out,
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                    )
+                attn_processors[key].load_state_dict(value_dict)
         else:
-            raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
+            raise ValueError(
+                f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
+            )
 
         # set correct dtype & device
         attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
@@ -285,16 +324,31 @@ class UNet2DConditionLoadersMixin:
 
         os.makedirs(save_directory, exist_ok=True)
 
-        model_to_save = AttnProcsLayers(self.attn_processors)
-
-        # Save the model
-        state_dict = model_to_save.state_dict()
+        is_custom_diffusion = any(
+            isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
+            for (_, x) in self.attn_processors.items()
+        )
+        if is_custom_diffusion:
+            model_to_save = AttnProcsLayers(
+                {
+                    y: x
+                    for (y, x) in self.attn_processors.items()
+                    if isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
+                }
+            )
+            state_dict = model_to_save.state_dict()
+            for name, attn in self.attn_processors.items():
+                if len(attn.state_dict()) == 0:
+                    state_dict[name] = {}
+        else:
+            model_to_save = AttnProcsLayers(self.attn_processors)
+            state_dict = model_to_save.state_dict()
 
         if weight_name is None:
             if safe_serialization:
-                weight_name = LORA_WEIGHT_NAME_SAFE
+                weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
             else:
-                weight_name = LORA_WEIGHT_NAME
+                weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME
 
         # Save the model
         save_function(state_dict, os.path.join(save_directory, weight_name))
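Together, the two hunks above teach `load_attn_procs`/`save_attn_procs` to round-trip Custom Diffusion processors alongside LoRA. A hedged sketch of that round trip; the directory name is hypothetical, and a real workflow would first install trained `CustomDiffusionAttnProcessor` instances on the UNet via training code:

```py
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# With Custom Diffusion processors present on the UNet, the is_custom_diffusion
# check above makes save_attn_procs pick pytorch_custom_diffusion_weights.bin.
pipe.unet.save_attn_procs("custom-diffusion-ckpt")

# On load, load_attn_procs keys off "custom_diffusion" substrings in the state dict.
pipe.unet.load_attn_procs("custom-diffusion-ckpt", weight_name="pytorch_custom_diffusion_weights.bin")
```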
@@ -356,7 +410,7 @@ class TextualInversionLoaderMixin:
             replacement = token
             i = 1
             while f"{token}_{i}" in tokenizer.added_tokens_encoder:
-                replacement += f"{token}_{i}"
+                replacement += f" {token}_{i}"
                 i += 1
 
             prompt = prompt.replace(token, replacement)
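This one-character fix matters for multi-vector textual inversion embeddings: without the space, the prompt conversion glued the extra tokens into one unknown token. A runnable sketch of the fixed expansion (token names hypothetical):

```py
# Simulate tokenizer.added_tokens_encoder for a 3-vector embedding
added_tokens = {"charturnerv2": 49408, "charturnerv2_1": 49409, "charturnerv2_2": 49410}

token = "charturnerv2"
replacement, i = token, 1
while f"{token}_{i}" in added_tokens:
    replacement += f" {token}_{i}"  # the fix: a space before each extra vector's token
    i += 1

print(replacement)  # charturnerv2 charturnerv2_1 charturnerv2_2
```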
@@ -431,6 +485,7 @@ class TextualInversionLoaderMixin:
         Example:
 
         To load a textual inversion embedding vector in `diffusers` format:
+
         ```py
         from diffusers import StableDiffusionPipeline
         import torch
@@ -456,13 +511,14 @@
         model_id = "runwayml/stable-diffusion-v1-5"
         pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
 
-        pipe.load_textual_inversion("./charturnerv2.pt")
+        pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")
 
         prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."
 
         image = pipe(prompt, num_inference_steps=50).images[0]
         image.save("character.png")
         ```
+
         """
         if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
             raise ValueError(
@@ -792,7 +848,7 @@ class LoraLoaderMixin:
         """
         # Loop over the original attention modules.
         for name, _ in self.text_encoder.named_modules():
-            if any(
+            if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
                 # Retrieve the module and its corresponding LoRA processor.
                 module = self.text_encoder.get_submodule(name)
                 # Construct a new function that performs the LoRA merging. We will monkey patch
@@ -1051,3 +1107,197 @@ class LoraLoaderMixin:
 
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
+
+
+class FromCkptMixin:
+    """This helper class allows directly loading a `.ckpt` Stable Diffusion checkpoint
+    into the respective pipeline classes."""
+
+    @classmethod
+    def from_ckpt(cls, pretrained_model_link_or_path, **kwargs):
+        r"""
+        Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights saved in the original .ckpt format.
+
+        The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
+
+        Parameters:
+            pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+                    - A link to the .ckpt file on the Hub. Should be in the format
+                      `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>"`
+                    - A path to a *file* containing all pipeline weights.
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+                will be automatically derived from the model's weights.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                standard cache should not be used.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+                file exists.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            use_auth_token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+                when running `huggingface-cli login` (stored in `~/.huggingface`).
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                identifier allowed by git.
+            use_safetensors (`bool`, *optional*):
+                If set to `True`, the pipeline will be loaded from `safetensors` weights. If set to `None` (the
+                default), the pipeline will load from `safetensors` weights if they are available *and* if
+                `safetensors` is installed. If set to `False`, the pipeline will *not* use `safetensors`.
+            extract_ema (`bool`, *optional*, defaults to `False`):
+                Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA
+                weights or not. Pass `True` to extract the EMA weights, which usually yield higher quality images for
+                inference. Non-EMA weights are usually better to continue fine-tuning.
+            upcast_attention (`bool`, *optional*, defaults to `None`):
+                Whether the attention computation should always be upcasted. This is necessary when running Stable
+                Diffusion 2.1.
+            image_size (`int`, *optional*, defaults to 512):
+                The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2
+                Base. Use 768 for Stable Diffusion v2.
+            prediction_type (`str`, *optional*):
+                The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion v1.X and Stable
+                Diffusion v2 Base. Use `'v_prediction'` for Stable Diffusion v2.
+            num_in_channels (`int`, *optional*, defaults to `None`):
+                The number of input channels. If `None`, it will be automatically inferred.
+            scheduler_type (`str`, *optional*, defaults to `'pndm'`):
+                Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
+                "ddim"]`.
+            load_safety_checker (`bool`, *optional*, defaults to `True`):
+                Whether to load the safety checker or not.
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to overwrite load- and saveable variables - *i.e.* the pipeline components - of the
+                specific pipeline class. The overwritten components are then directly passed to the pipeline's
+                `__init__` method. See example below for more information.
+
+        Examples:
+
+        ```py
+        >>> import torch
+        >>> from diffusers import StableDiffusionPipeline
+
+        >>> # Download pipeline from huggingface.co and cache.
+        >>> pipeline = StableDiffusionPipeline.from_ckpt(
+        ...     "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
+        ... )
+
+        >>> # Download pipeline from local file
+        >>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
+        >>> pipeline = StableDiffusionPipeline.from_ckpt("./v1-5-pruned-emaonly")
+
+        >>> # Enable float16 and move to GPU
+        >>> pipeline = StableDiffusionPipeline.from_ckpt(
+        ...     "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
+        ...     torch_dtype=torch.float16,
+        ... )
+        >>> pipeline.to("cuda")
+        ```
+        """
+        # import here to avoid circular dependency
+        from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
+
+        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        resume_download = kwargs.pop("resume_download", False)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
+        extract_ema = kwargs.pop("extract_ema", False)
+        image_size = kwargs.pop("image_size", 512)
+        scheduler_type = kwargs.pop("scheduler_type", "pndm")
+        num_in_channels = kwargs.pop("num_in_channels", None)
+        upcast_attention = kwargs.pop("upcast_attention", None)
+        load_safety_checker = kwargs.pop("load_safety_checker", True)
+        prediction_type = kwargs.pop("prediction_type", None)
+
+        torch_dtype = kwargs.pop("torch_dtype", None)
+
+        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+
+        pipeline_name = cls.__name__
+        file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
+        from_safetensors = file_extension == "safetensors"
+
+        if from_safetensors and use_safetensors is False:
+            raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
+
+        # TODO: For now we only support stable diffusion
+        stable_unclip = None
+        controlnet = False
+
+        if pipeline_name == "StableDiffusionControlNetPipeline":
+            model_type = "FrozenCLIPEmbedder"
+            controlnet = True
+        elif "StableDiffusion" in pipeline_name:
+            model_type = "FrozenCLIPEmbedder"
+        elif pipeline_name == "StableUnCLIPPipeline":
+            model_type = "FrozenOpenCLIPEmbedder"
+            stable_unclip = "txt2img"
+        elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
+            model_type = "FrozenOpenCLIPEmbedder"
+            stable_unclip = "img2img"
+        elif pipeline_name == "PaintByExamplePipeline":
+            model_type = "PaintByExample"
+        elif pipeline_name == "LDMTextToImagePipeline":
+            model_type = "LDMTextToImage"
+        else:
+            raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
+
+        # remove huggingface url
+        for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
+            if pretrained_model_link_or_path.startswith(prefix):
+                pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
+
+        # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
+        ckpt_path = Path(pretrained_model_link_or_path)
+        if not ckpt_path.is_file():
+            # get repo_id and (potentially nested) file path of ckpt in repo
+            repo_id = str(Path().joinpath(*ckpt_path.parts[:2]))
+            file_path = str(Path().joinpath(*ckpt_path.parts[2:]))
+
+            if file_path.startswith("blob/"):
+                file_path = file_path[len("blob/") :]
+
+            if file_path.startswith("main/"):
+                file_path = file_path[len("main/") :]
+
+            pretrained_model_link_or_path = hf_hub_download(
+                repo_id,
+                filename=file_path,
+                cache_dir=cache_dir,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                force_download=force_download,
+            )
+
+        pipe = download_from_original_stable_diffusion_ckpt(
+            pretrained_model_link_or_path,
+            pipeline_class=cls,
+            model_type=model_type,
+            stable_unclip=stable_unclip,
+            controlnet=controlnet,
+            from_safetensors=from_safetensors,
+            extract_ema=extract_ema,
+            image_size=image_size,
+            scheduler_type=scheduler_type,
+            num_in_channels=num_in_channels,
+            upcast_attention=upcast_attention,
+            load_safety_checker=load_safety_checker,
+            prediction_type=prediction_type,
+        )
+
+        if torch_dtype is not None:
+            pipe.to(torch_dtype=torch_dtype)
+
+        return pipe
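One subtlety in `from_ckpt` worth tracing is how a Hub blob link is split into the `repo_id` and `filename` arguments for `hf_hub_download`. A worked example using only the standard library (the link is hypothetical):

```py
from pathlib import Path

# a prefix such as "https://huggingface.co/" has already been stripped
link = "runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt"
parts = Path(link).parts
repo_id = str(Path().joinpath(*parts[:2]))    # "runwayml/stable-diffusion-v1-5"
file_path = str(Path().joinpath(*parts[2:]))  # "blob/main/v1-5-pruned-emaonly.ckpt"
for marker in ("blob/", "main/"):             # drop the web-UI path segments
    if file_path.startswith(marker):
        file_path = file_path[len(marker):]
print(repo_id, file_path)  # runwayml/stable-diffusion-v1-5 v1-5-pruned-emaonly.ckpt
```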
diffusers/models/attention.py
CHANGED
@@ -60,7 +60,6 @@ class AttentionBlock(nn.Module):
         self.channels = channels
 
         self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
-        self.num_head_size = num_head_channels
         self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
 
         # define q,k,v as linear layers
@@ -72,20 +71,30 @@
         self.proj_attn = nn.Linear(channels, channels, bias=True)
 
         self._use_memory_efficient_attention_xformers = False
+        self._use_2_0_attn = True
         self._attention_op = None
 
-    def reshape_heads_to_batch_dim(self, tensor):
+    def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.num_heads
         tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3)
+        if merge_head_and_batch:
+            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
         return tensor
 
-    def reshape_batch_dim_to_heads(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
+    def reshape_batch_dim_to_heads(self, tensor, unmerge_head_and_batch=True):
         head_size = self.num_heads
-        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+
+        if unmerge_head_and_batch:
+            batch_head_size, seq_len, dim = tensor.shape
+            batch_size = batch_head_size // head_size
+
+            tensor = tensor.reshape(batch_size, head_size, seq_len, dim)
+        else:
+            batch_size, _, seq_len, dim = tensor.shape
+
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim * head_size)
         return tensor
 
     def set_use_memory_efficient_attention_xformers(
@@ -134,14 +143,24 @@
 
         scale = 1 / math.sqrt(self.channels / self.num_heads)
 
-        query_proj = self.reshape_heads_to_batch_dim(query_proj)
-        key_proj = self.reshape_heads_to_batch_dim(key_proj)
-        value_proj = self.reshape_heads_to_batch_dim(value_proj)
+        _use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers
+        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn
+
+        query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
+        key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)
+        value_proj = self.reshape_heads_to_batch_dim(value_proj, merge_head_and_batch=not use_torch_2_0_attn)
 
         if self._use_memory_efficient_attention_xformers:
             # Memory efficient attention
             hidden_states = xformers.ops.memory_efficient_attention(
-                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
+                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op, scale=scale
+            )
+            hidden_states = hidden_states.to(query_proj.dtype)
+        elif use_torch_2_0_attn:
+            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            hidden_states = F.scaled_dot_product_attention(
+                query_proj, key_proj, value_proj, dropout_p=0.0, is_causal=False
             )
             hidden_states = hidden_states.to(query_proj.dtype)
         else:
@@ -162,7 +181,7 @@
         hidden_states = torch.bmm(attention_probs, value_proj)
 
         # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states, unmerge_head_and_batch=not use_torch_2_0_attn)
 
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
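The `AttentionBlock` changes above dispatch to PyTorch 2.0's fused `F.scaled_dot_product_attention` when it exists, keeping heads as a separate dimension for the fused kernel while still folding them into the batch for the legacy bmm path. A self-contained sketch of that dispatch idea (not the diffusers code itself; for brevity both branches here keep the 4D layout):

```py
import torch
import torch.nn.functional as F

def attention(q, k, v, num_heads):
    b, s, d = q.shape
    # (batch, seq, dim) -> (batch, heads, seq, head_dim), as the fused kernel expects
    q, k, v = (t.reshape(b, s, num_heads, d // num_heads).permute(0, 2, 1, 3) for t in (q, k, v))
    if hasattr(F, "scaled_dot_product_attention"):  # PyTorch >= 2.0
        out = F.scaled_dot_product_attention(q, k, v)  # scaling handled internally
    else:
        scale = (d // num_heads) ** -0.5
        attn = (q @ k.transpose(-1, -2) * scale).softmax(dim=-1)
        out = attn @ v
    return out.permute(0, 2, 1, 3).reshape(b, s, d)

x = torch.randn(2, 16, 64)
print(attention(x, x, x, num_heads=8).shape)  # torch.Size([2, 16, 64])
```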