diffusers 0.19.3__py3-none-any.whl → 0.20.1__py3-none-any.whl
- diffusers/__init__.py +3 -1
- diffusers/commands/fp16_safetensors.py +2 -7
- diffusers/configuration_utils.py +23 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/loaders.py +62 -64
- diffusers/models/__init__.py +1 -0
- diffusers/models/activations.py +2 -0
- diffusers/models/attention.py +45 -1
- diffusers/models/autoencoder_tiny.py +193 -0
- diffusers/models/controlnet.py +1 -1
- diffusers/models/embeddings.py +56 -0
- diffusers/models/lora.py +0 -6
- diffusers/models/modeling_flax_utils.py +28 -2
- diffusers/models/modeling_utils.py +33 -16
- diffusers/models/transformer_2d.py +26 -9
- diffusers/models/unet_1d.py +2 -2
- diffusers/models/unet_2d_blocks.py +106 -56
- diffusers/models/unet_2d_condition.py +20 -5
- diffusers/models/vae.py +106 -1
- diffusers/pipelines/__init__.py +1 -0
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +10 -3
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -3
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
- diffusers/pipelines/auto_pipeline.py +33 -43
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -2
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +15 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +14 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +157 -10
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +43 -2
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +44 -2
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/pipeline_flax_utils.py +41 -4
- diffusers/pipelines/pipeline_utils.py +60 -16
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +81 -37
- diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +12 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +832 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +17 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +10 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +3 -5
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +75 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +76 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +1 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +10 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +10 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +11 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +1 -1
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +131 -28
- diffusers/schedulers/scheduling_consistency_models.py +70 -57
- diffusers/schedulers/scheduling_ddim.py +76 -71
- diffusers/schedulers/scheduling_ddim_inverse.py +76 -44
- diffusers/schedulers/scheduling_ddim_parallel.py +11 -8
- diffusers/schedulers/scheduling_ddpm.py +68 -67
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -15
- diffusers/schedulers/scheduling_deis_multistep.py +93 -85
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +118 -120
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +116 -109
- diffusers/schedulers/scheduling_dpmsolver_sde.py +57 -43
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +122 -121
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +54 -44
- diffusers/schedulers/scheduling_euler_discrete.py +63 -56
- diffusers/schedulers/scheduling_heun_discrete.py +57 -45
- diffusers/schedulers/scheduling_ipndm.py +27 -22
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +54 -41
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +52 -41
- diffusers/schedulers/scheduling_karras_ve.py +55 -45
- diffusers/schedulers/scheduling_lms_discrete.py +58 -52
- diffusers/schedulers/scheduling_pndm.py +77 -62
- diffusers/schedulers/scheduling_repaint.py +56 -38
- diffusers/schedulers/scheduling_sde_ve.py +62 -50
- diffusers/schedulers/scheduling_sde_vp.py +32 -11
- diffusers/schedulers/scheduling_unclip.py +3 -3
- diffusers/schedulers/scheduling_unipc_multistep.py +131 -91
- diffusers/schedulers/scheduling_utils.py +41 -35
- diffusers/schedulers/scheduling_utils_flax.py +8 -2
- diffusers/schedulers/scheduling_vq_diffusion.py +39 -68
- diffusers/utils/__init__.py +2 -2
- diffusers/utils/dummy_pt_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +15 -0
- diffusers/utils/hub_utils.py +105 -2
- diffusers/utils/import_utils.py +0 -4
- diffusers/utils/pil_utils.py +19 -0
- {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/METADATA +5 -7
- {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/RECORD +113 -112
- {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/WHEEL +1 -1
- {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/entry_points.txt +0 -1
- diffusers/models/cross_attention.py +0 -94
- {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/LICENSE +0 -0
- {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/top_level.txt +0 -0
diffusers/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.19.3"
+__version__ = "0.20.1"

 from .configuration_utils import ConfigMixin
 from .utils import (
@@ -38,6 +38,7 @@ else:
     from .models import (
         AsymmetricAutoencoderKL,
         AutoencoderKL,
+        AutoencoderTiny,
         ControlNetModel,
         ModelMixin,
         MultiAdapter,
@@ -170,6 +171,7 @@ else:
         StableDiffusionControlNetPipeline,
         StableDiffusionDepth2ImgPipeline,
         StableDiffusionDiffEditPipeline,
+        StableDiffusionGLIGENPipeline,
         StableDiffusionImageVariationPipeline,
         StableDiffusionImg2ImgPipeline,
         StableDiffusionInpaintPipeline,
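The new `StableDiffusionGLIGENPipeline` export pairs with the new `pipeline_stable_diffusion_gligen.py` module listed in the summary above. A hedged usage sketch, assuming the community GLIGEN checkpoint `masterful/gligen-1-4-generation-text-box` and a CUDA device; the prompt and box coordinates are illustrative:

```python
# Sketch only: checkpoint id, prompt, and boxes are illustrative, not part of this diff.
import torch
from diffusers import StableDiffusionGLIGENPipeline

pipe = StableDiffusionGLIGENPipeline.from_pretrained(
    "masterful/gligen-1-4-generation-text-box", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    prompt="a waterfall and a modern high speed train",
    gligen_phrases=["a waterfall", "a modern high speed train"],  # one phrase per box
    gligen_boxes=[[0.14, 0.26, 0.43, 0.86], [0.45, 0.25, 0.92, 0.75]],  # normalized xyxy
    gligen_scheduled_sampling_beta=1.0,
    num_inference_steps=50,
).images[0]
image.save("gligen_text_box.png")
```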
diffusers/commands/fp16_safetensors.py
CHANGED
@@ -27,7 +27,7 @@ import torch
 from huggingface_hub import hf_hub_download
 from packaging import version

-from ..utils import is_safetensors_available, logging
+from ..utils import logging
 from . import BaseDiffusersCLICommand


@@ -68,12 +68,7 @@ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
         self.local_ckpt_dir = f"/tmp/{ckpt_id}"
         self.fp16 = fp16

-        if is_safetensors_available():
-            self.use_safetensors = use_safetensors
-        else:
-            raise ImportError(
-                "When `use_safetensors` is set to True, the `safetensors` library needs to be installed. Install it via `pip install safetensors`."
-            )
+        self.use_safetensors = use_safetensors

         if not self.use_safetensors and not self.fp16:
             raise NotImplementedError(
diffusers/configuration_utils.py
CHANGED
@@ -26,7 +26,7 @@ from pathlib import PosixPath
 from typing import Any, Dict, Tuple, Union

 import numpy as np
-from huggingface_hub import hf_hub_download
+from huggingface_hub import create_repo, hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError

@@ -144,6 +144,12 @@ class ConfigMixin:
         Args:
             save_directory (`str` or `os.PathLike`):
                 Directory where the configuration JSON file is saved (will be created if it does not exist).
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
+            kwargs (`Dict[str, Any]`, *optional*):
+                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
         if os.path.isfile(save_directory):
             raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
@@ -156,6 +162,22 @@ class ConfigMixin:
         self.to_json_file(output_config_file)
         logger.info(f"Configuration saved in {output_config_file}")

+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            private = kwargs.pop("private", False)
+            create_pr = kwargs.pop("create_pr", False)
+            token = kwargs.pop("token", None)
+            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
+
+            self._upload_folder(
+                save_directory,
+                repo_id,
+                token=token,
+                commit_message=commit_message,
+                create_pr=create_pr,
+            )
+
     @classmethod
     def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
         r"""
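The `push_to_hub` addition to `ConfigMixin.save_config` follows the `PushToHubMixin` flow shown in the hunk above. A minimal sketch, assuming a valid Hub login; `"your-username/ddim-config"` is an illustrative placeholder repo id:

```python
# Sketch: save a config locally and push it to the Hub in one call (new in 0.20).
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()
scheduler.save_config(
    "ddim-config",                        # local directory the config JSON is written to
    push_to_hub=True,                     # triggers create_repo(...) + _upload_folder(...)
    repo_id="your-username/ddim-config",  # optional; defaults to the directory name
    private=True,
    commit_message="Add scheduler config",
)
```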
diffusers/dependency_versions_table.py
CHANGED
@@ -29,7 +29,7 @@ deps = {
     "pytest": "pytest",
     "pytest-timeout": "pytest-timeout",
     "pytest-xdist": "pytest-xdist",
-    "ruff": "ruff
+    "ruff": "ruff==0.0.280",
     "safetensors": "safetensors>=0.3.1",
     "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
     "scipy": "scipy",
diffusers/loaders.py
CHANGED
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import os
 import re
 import warnings
@@ -21,6 +22,7 @@ from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union

 import requests
+import safetensors
 import torch
 import torch.nn.functional as F
 from huggingface_hub import hf_hub_download
@@ -33,16 +35,12 @@ from .utils import (
     deprecate,
     is_accelerate_available,
     is_omegaconf_available,
-    is_safetensors_available,
     is_transformers_available,
     logging,
 )
 from .utils.import_utils import BACKENDS_MAPPING


-if is_safetensors_available():
-    import safetensors
-
 if is_transformers_available():
     from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel, PreTrainedTokenizer

@@ -189,7 +187,7 @@ class UNet2DConditionLoadersMixin:
         r"""
         Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
         defined in
-        [`cross_attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py)
+        [`attention_processor.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py)
         and be a `torch.nn.Module` class.

         Parameters:
@@ -258,15 +256,12 @@ class UNet2DConditionLoadersMixin:
         # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
         # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
         network_alphas = kwargs.pop("network_alphas", None)
-
-        if use_safetensors and not is_safetensors_available():
-            raise ValueError(
-                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
-            )
+        is_network_alphas_none = network_alphas is None

         allow_pickle = False
+
         if use_safetensors is None:
-            use_safetensors = is_safetensors_available()
+            use_safetensors = True
             allow_pickle = True

         user_agent = {
@@ -349,13 +344,20 @@ class UNet2DConditionLoadersMixin:

             # Create another `mapped_network_alphas` dictionary so that we can properly map them.
             if network_alphas is not None:
-                for k in network_alphas:
+                network_alphas_ = copy.deepcopy(network_alphas)
+                for k in network_alphas_:
                     if k.replace(".alpha", "") in key:
-                        mapped_network_alphas.update({attn_processor_key: network_alphas.get(k)})
+                        mapped_network_alphas.update({attn_processor_key: network_alphas.pop(k)})
+
+        if not is_network_alphas_none:
+            if len(network_alphas) > 0:
+                raise ValueError(
+                    f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
+                )

         if len(state_dict) > 0:
             raise ValueError(
-                f"The state_dict has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
+                f"The `state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
             )

         for key, value_dict in lora_grouped_dict.items():
@@ -434,14 +436,6 @@ class UNet2DConditionLoadersMixin:
                     v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight"),
                     out_rank=rank_mapping.get("to_out_lora.down.weight"),
                     out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight"),
-                    # rank=rank_mapping.get("to_k_lora.down.weight", None),
-                    # hidden_size=hidden_size_mapping.get("to_k_lora.up.weight", None),
-                    # q_rank=rank_mapping.get("to_q_lora.down.weight", None),
-                    # q_hidden_size=hidden_size_mapping.get("to_q_lora.up.weight", None),
-                    # v_rank=rank_mapping.get("to_v_lora.down.weight", None),
-                    # v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight", None),
-                    # out_rank=rank_mapping.get("to_out_lora.down.weight", None),
-                    # out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight", None),
                 )
             else:
                 attn_processors[key] = attn_processor_class(
@@ -496,9 +490,6 @@ class UNet2DConditionLoadersMixin:
         # set ff layers
         for target_module, lora_layer in non_attn_lora_layers:
             target_module.set_lora_layer(lora_layer)
-            # It should raise an error if we don't have a set lora here
-            # if hasattr(target_module, "set_lora_layer"):
-            #     target_module.set_lora_layer(lora_layer)

     def save_attn_procs(
         self,
@@ -506,7 +497,7 @@ class UNet2DConditionLoadersMixin:
         is_main_process: bool = True,
         weight_name: str = None,
         save_function: Callable = None,
-        safe_serialization: bool = False,
+        safe_serialization: bool = True,
         **kwargs,
     ):
         r"""
@@ -524,19 +515,14 @@ class UNet2DConditionLoadersMixin:
                 The function to use to save the state dictionary. Useful during distributed training when you need to
                 replace `torch.save` with another method. Can be configured with the environment variable
                 `DIFFUSERS_SAVE_MODE`.
-
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
         from .models.attention_processor import (
             CustomDiffusionAttnProcessor,
             CustomDiffusionXFormersAttnProcessor,
         )

-        weight_name = weight_name or deprecate(
-            "weights_name",
-            "0.20.0",
-            "`weights_name` is deprecated, please use `weight_name` instead.",
-            take_from=kwargs,
-        )
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
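Because `safe_serialization` now defaults to `True`, `save_attn_procs` writes a `.safetensors` file unless told otherwise. A hedged sketch; the directory names are placeholders:

```python
# Sketch: attention-processor (LoRA) saving now defaults to safetensors serialization.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# ... after attaching/training LoRA attention processors on pipe.unet ...
pipe.unet.save_attn_procs("my-lora")                                # writes pytorch_lora_weights.safetensors
pipe.unet.save_attn_procs("my-lora-bin", safe_serialization=False)  # opt back into the pickled .bin format
```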
@@ -766,14 +752,9 @@ class TextualInversionLoaderMixin:
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)

-        if use_safetensors and not is_safetensors_available():
-            raise ValueError(
-                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
-            )
-
         allow_pickle = False
         if use_safetensors is None:
-            use_safetensors = is_safetensors_available()
+            use_safetensors = True
             allow_pickle = True

         user_agent = {
@@ -1023,14 +1004,9 @@ class LoraLoaderMixin:
         unet_config = kwargs.pop("unet_config", None)
         use_safetensors = kwargs.pop("use_safetensors", None)

-        if use_safetensors and not is_safetensors_available():
-            raise ValueError(
-                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
-            )
-
         allow_pickle = False
         if use_safetensors is None:
-            use_safetensors = is_safetensors_available()
+            use_safetensors = True
             allow_pickle = True

         user_agent = {
@@ -1063,6 +1039,7 @@ class LoraLoaderMixin:
                 if not allow_pickle:
                     raise e
                 # try loading non-safetensors weights
+                model_file = None
                 pass
         if model_file is None:
             model_file = _get_model_file(
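On the loading side, `use_safetensors` now defaults to `True` with `allow_pickle` as the fallback, so safetensors weights are tried before pickled ones. A sketch; the LoRA repo id is an illustrative placeholder:

```python
# Sketch: LoRA loading prefers *.safetensors and falls back to .bin only when pickling is allowed.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_lora_weights("some-user/some-lora")                         # tries safetensors first
pipe.load_lora_weights("some-user/some-lora", use_safetensors=False)  # force the pickle path
```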
@@ -1258,9 +1235,10 @@ class LoraLoaderMixin:
         keys = list(state_dict.keys())
         prefix = cls.text_encoder_name if prefix is None else prefix

+        # Safe prefix to check with.
         if any(cls.text_encoder_name in key for key in keys):
             # Load the layers corresponding to text encoder and make necessary adjustments.
-            text_encoder_keys = [k for k in keys if k.startswith(prefix)]
+            text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
             text_encoder_lora_state_dict = {
                 k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
             }
@@ -1310,6 +1288,14 @@ class LoraLoaderMixin:
             ].shape[1]
             patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())

+            if network_alphas is not None:
+                alpha_keys = [
+                    k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
+                ]
+                network_alphas = {
+                    k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+                }
+
             cls._modify_text_encoder(
                 text_encoder,
                 lora_scale,
@@ -1371,12 +1357,13 @@ class LoraLoaderMixin:

         lora_parameters = []
         network_alphas = {} if network_alphas is None else network_alphas
+        is_network_alphas_populated = len(network_alphas) > 0

         for name, attn_module in text_encoder_attn_modules(text_encoder):
-            query_alpha = network_alphas.get(name + ".to_q_lora.down.weight.alpha")
-            key_alpha = network_alphas.get(name + ".to_k_lora.down.weight.alpha")
-            value_alpha = network_alphas.get(name + ".to_v_lora.down.weight.alpha")
-
+            query_alpha = network_alphas.pop(name + ".to_q_lora.down.weight.alpha", None)
+            key_alpha = network_alphas.pop(name + ".to_k_lora.down.weight.alpha", None)
+            value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None)
+            out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None)

             attn_module.q_proj = PatchedLoraProjection(
                 attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=rank, dtype=dtype
@@ -1394,14 +1381,14 @@ class LoraLoaderMixin:
             lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())

             attn_module.out_proj = PatchedLoraProjection(
-                attn_module.out_proj, lora_scale, network_alpha=None, rank=rank, dtype=dtype
+                attn_module.out_proj, lora_scale, network_alpha=out_alpha, rank=rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())

         if patch_mlp:
             for name, mlp_module in text_encoder_mlp_modules(text_encoder):
-                fc1_alpha = network_alphas.get(name + ".fc1.lora_linear_layer.down.weight.alpha")
-                fc2_alpha = network_alphas.get(name + ".fc2.lora_linear_layer.down.weight.alpha")
+                fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha")
+                fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha")

                 mlp_module.fc1 = PatchedLoraProjection(
                     mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=rank, dtype=dtype
@@ -1413,6 +1400,11 @@ class LoraLoaderMixin:
                 )
                 lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())

+        if is_network_alphas_populated and len(network_alphas) > 0:
+            raise ValueError(
+                f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
+            )
+
         return lora_parameters

     @classmethod
@@ -1424,7 +1416,7 @@ class LoraLoaderMixin:
         is_main_process: bool = True,
         weight_name: str = None,
         save_function: Callable = None,
-        safe_serialization: bool = False,
+        safe_serialization: bool = True,
     ):
         r"""
         Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -1445,6 +1437,8 @@ class LoraLoaderMixin:
                 The function to use to save the state dictionary. Useful during distributed training when you need to
                 replace `torch.save` with another method. Can be configured with the environment variable
                 `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
         # Create a flat dictionary.
         state_dict = {}
@@ -1526,10 +1520,6 @@ class LoraLoaderMixin:
         lora_name_up = lora_name + ".lora_up.weight"
         lora_name_alpha = lora_name + ".alpha"

-        # if lora_name_alpha in state_dict:
-        #     alpha = state_dict.pop(lora_name_alpha).item()
-        #     network_alphas.update({lora_name_alpha: alpha})
-
         if lora_name.startswith("lora_unet_"):
             diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
@@ -1851,7 +1841,7 @@ class FromSingleFileMixin:

         torch_dtype = kwargs.pop("torch_dtype", None)

-        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+        use_safetensors = kwargs.pop("use_safetensors", None)

         pipeline_name = cls.__name__
         file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
@@ -1892,16 +1882,24 @@ class FromSingleFileMixin:
             raise ValueError(f"Unhandled pipeline class: {pipeline_name}")

         # remove huggingface url
-        for prefix in ["https://huggingface.co/", "huggingface.co/"]:
+        has_valid_url_prefix = False
+        valid_url_prefixes = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
+        for prefix in valid_url_prefixes:
             if pretrained_model_link_or_path.startswith(prefix):
                 pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
+                has_valid_url_prefix = True

         # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
         ckpt_path = Path(pretrained_model_link_or_path)
         if not ckpt_path.is_file():
+            if not has_valid_url_prefix:
+                raise ValueError(
+                    f"The provided path is either not a file or a valid huggingface URL was not provided. Valid URLs begin with {', '.join(valid_url_prefixes)}"
+                )
+
             # get repo_id and (potentially nested) file path of ckpt in repo
-            repo_id =
-            file_path =
+            repo_id = "/".join(ckpt_path.parts[:2])
+            file_path = "/".join(ckpt_path.parts[2:])

             if file_path.startswith("blob/"):
                 file_path = file_path[len("blob/") :]
@@ -2048,7 +2046,7 @@ class FromOriginalVAEMixin:

         torch_dtype = kwargs.pop("torch_dtype", None)

-        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+        use_safetensors = kwargs.pop("use_safetensors", None)

         file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
         from_safetensors = file_extension == "safetensors"
@@ -2221,7 +2219,7 @@ class FromOriginalControlnetMixin:

         torch_dtype = kwargs.pop("torch_dtype", None)

-        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+        use_safetensors = kwargs.pop("use_safetensors", None)

         file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
         from_safetensors = file_extension == "safetensors"
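The URL handling above makes `from_single_file` reject inputs that are neither a local file nor a recognized Hub URL, and adds `hf.co` short links to the accepted prefixes. A sketch; the checkpoint URL is the standard v1-5 single-file example and assumed to be reachable:

```python
# Sketch: single-file loading with one of the newly validated URL prefixes.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_single_file(
    "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
)
# Anything that is neither an existing local file nor a recognized URL now raises the ValueError above.
```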
diffusers/models/__init__.py
CHANGED
@@ -19,6 +19,7 @@ if is_torch_available():
     from .adapter import MultiAdapter, T2IAdapter
     from .autoencoder_asym_kl import AsymmetricAutoencoderKL
     from .autoencoder_kl import AutoencoderKL
+    from .autoencoder_tiny import AutoencoderTiny
     from .controlnet import ControlNetModel
     from .dual_transformer_2d import DualTransformer2DModel
     from .modeling_utils import ModelMixin
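`AutoencoderTiny` (the new `autoencoder_tiny.py` module in the summary) wraps the tiny TAESD autoencoder; the usual pattern is to swap it in as a pipeline's VAE for fast decoding. A hedged sketch, assuming the community checkpoint `madebyollin/taesd` and a CUDA device:

```python
# Sketch: replace a pipeline's VAE with the new AutoencoderTiny for cheap decoding.
import torch
from diffusers import AutoencoderTiny, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

image = pipe("slice of delicious New York-style cheesecake", num_inference_steps=25).images[0]
```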
diffusers/models/activations.py
CHANGED
diffusers/models/attention.py
CHANGED
@@ -24,6 +24,38 @@ from .embeddings import CombinedTimestepLabelEmbeddings
 from .lora import LoRACompatibleLinear


+@maybe_allow_in_graph
+class GatedSelfAttentionDense(nn.Module):
+    def __init__(self, query_dim, context_dim, n_heads, d_head):
+        super().__init__()
+
+        # we need a linear projection since we need cat visual feature and obj feature
+        self.linear = nn.Linear(context_dim, query_dim)
+
+        self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
+        self.ff = FeedForward(query_dim, activation_fn="geglu")
+
+        self.norm1 = nn.LayerNorm(query_dim)
+        self.norm2 = nn.LayerNorm(query_dim)
+
+        self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
+        self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
+
+        self.enabled = True
+
+    def forward(self, x, objs):
+        if not self.enabled:
+            return x
+
+        n_visual = x.shape[1]
+        objs = self.linear(objs)
+
+        x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
+        x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
+
+        return x
+
+
 @maybe_allow_in_graph
 class BasicTransformerBlock(nn.Module):
     r"""
@@ -62,6 +94,7 @@ class BasicTransformerBlock(nn.Module):
         norm_elementwise_affine: bool = True,
         norm_type: str = "layer_norm",
         final_dropout: bool = False,
+        attention_type: str = "default",
     ):
         super().__init__()
         self.only_cross_attention = only_cross_attention
@@ -120,6 +153,10 @@ class BasicTransformerBlock(nn.Module):
         self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

+        # 4. Fuser
+        if attention_type == "gated":
+            self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
+
         # let chunk size default to None
         self._chunk_size = None
         self._chunk_dim = 0
@@ -150,7 +187,9 @@ class BasicTransformerBlock(nn.Module):
         else:
             norm_hidden_states = self.norm1(hidden_states)

-        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+        # 0. Prepare GLIGEN inputs
+        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

         attn_output = self.attn1(
             norm_hidden_states,
@@ -162,6 +201,11 @@ class BasicTransformerBlock(nn.Module):
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = attn_output + hidden_states

+        # 1.5 GLIGEN Control
+        if gligen_kwargs is not None:
+            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+        # 1.5 ends
+
         # 2. Cross-Attention
         if self.attn2 is not None:
             norm_hidden_states = (
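The `gligen` entry in `cross_attention_kwargs` is popped before self-attention and routed to the fuser, so grounding tokens never reach `attn1` itself. A standalone sketch of the block-level behavior; the dimensions and token counts are illustrative, not prescribed by the diff:

```python
# Sketch: exercising the new gated (GLIGEN) fuser path directly on a BasicTransformerBlock.
import torch
from diffusers.models.attention import BasicTransformerBlock

block = BasicTransformerBlock(
    dim=320,
    num_attention_heads=8,
    attention_head_dim=40,
    cross_attention_dim=768,
    attention_type="gated",  # new in 0.20: installs GatedSelfAttentionDense as self.fuser
)

hidden_states = torch.randn(1, 64, 320)  # visual tokens
objs = torch.randn(1, 30, 768)           # grounding tokens; the fuser projects 768 -> 320

out = block(
    hidden_states,
    encoder_hidden_states=torch.randn(1, 77, 768),
    cross_attention_kwargs={"gligen": {"objs": objs}},  # popped in forward, fed to self.fuser
)
print(out.shape)  # torch.Size([1, 64, 320])
```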