diffusers 0.19.3__py3-none-any.whl → 0.20.1__py3-none-any.whl

Files changed (114)
  1. diffusers/__init__.py +3 -1
  2. diffusers/commands/fp16_safetensors.py +2 -7
  3. diffusers/configuration_utils.py +23 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/loaders.py +62 -64
  6. diffusers/models/__init__.py +1 -0
  7. diffusers/models/activations.py +2 -0
  8. diffusers/models/attention.py +45 -1
  9. diffusers/models/autoencoder_tiny.py +193 -0
  10. diffusers/models/controlnet.py +1 -1
  11. diffusers/models/embeddings.py +56 -0
  12. diffusers/models/lora.py +0 -6
  13. diffusers/models/modeling_flax_utils.py +28 -2
  14. diffusers/models/modeling_utils.py +33 -16
  15. diffusers/models/transformer_2d.py +26 -9
  16. diffusers/models/unet_1d.py +2 -2
  17. diffusers/models/unet_2d_blocks.py +106 -56
  18. diffusers/models/unet_2d_condition.py +20 -5
  19. diffusers/models/vae.py +106 -1
  20. diffusers/pipelines/__init__.py +1 -0
  21. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +10 -3
  22. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -3
  23. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
  24. diffusers/pipelines/auto_pipeline.py +33 -43
  25. diffusers/pipelines/controlnet/multicontrolnet.py +4 -2
  26. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -4
  27. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +15 -7
  28. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +14 -4
  29. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +157 -10
  30. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -10
  31. diffusers/pipelines/deepfloyd_if/pipeline_if.py +1 -1
  32. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +1 -1
  33. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1 -1
  34. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1 -1
  35. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1 -1
  36. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +1 -1
  37. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +43 -2
  38. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +44 -2
  39. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
  40. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  41. diffusers/pipelines/pipeline_flax_utils.py +41 -4
  42. diffusers/pipelines/pipeline_utils.py +60 -16
  43. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +2 -2
  44. diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  45. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +81 -37
  46. diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +10 -3
  47. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -3
  48. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -3
  49. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +10 -3
  50. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +12 -5
  51. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +832 -0
  52. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -3
  53. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +10 -3
  54. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +10 -3
  55. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +9 -2
  56. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +17 -8
  57. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +10 -3
  58. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +10 -3
  59. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +10 -3
  60. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +10 -3
  61. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +10 -3
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +10 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +10 -3
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +10 -3
  65. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +3 -5
  66. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +75 -3
  67. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +76 -6
  68. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +1 -2
  69. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +10 -3
  70. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +10 -3
  71. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +11 -4
  72. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +1 -1
  73. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +131 -28
  74. diffusers/schedulers/scheduling_consistency_models.py +70 -57
  75. diffusers/schedulers/scheduling_ddim.py +76 -71
  76. diffusers/schedulers/scheduling_ddim_inverse.py +76 -44
  77. diffusers/schedulers/scheduling_ddim_parallel.py +11 -8
  78. diffusers/schedulers/scheduling_ddpm.py +68 -67
  79. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -15
  80. diffusers/schedulers/scheduling_deis_multistep.py +93 -85
  81. diffusers/schedulers/scheduling_dpmsolver_multistep.py +118 -120
  82. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +116 -109
  83. diffusers/schedulers/scheduling_dpmsolver_sde.py +57 -43
  84. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +122 -121
  85. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +54 -44
  86. diffusers/schedulers/scheduling_euler_discrete.py +63 -56
  87. diffusers/schedulers/scheduling_heun_discrete.py +57 -45
  88. diffusers/schedulers/scheduling_ipndm.py +27 -22
  89. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +54 -41
  90. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +52 -41
  91. diffusers/schedulers/scheduling_karras_ve.py +55 -45
  92. diffusers/schedulers/scheduling_lms_discrete.py +58 -52
  93. diffusers/schedulers/scheduling_pndm.py +77 -62
  94. diffusers/schedulers/scheduling_repaint.py +56 -38
  95. diffusers/schedulers/scheduling_sde_ve.py +62 -50
  96. diffusers/schedulers/scheduling_sde_vp.py +32 -11
  97. diffusers/schedulers/scheduling_unclip.py +3 -3
  98. diffusers/schedulers/scheduling_unipc_multistep.py +131 -91
  99. diffusers/schedulers/scheduling_utils.py +41 -35
  100. diffusers/schedulers/scheduling_utils_flax.py +8 -2
  101. diffusers/schedulers/scheduling_vq_diffusion.py +39 -68
  102. diffusers/utils/__init__.py +2 -2
  103. diffusers/utils/dummy_pt_objects.py +15 -0
  104. diffusers/utils/dummy_torch_and_transformers_objects.py +15 -0
  105. diffusers/utils/hub_utils.py +105 -2
  106. diffusers/utils/import_utils.py +0 -4
  107. diffusers/utils/pil_utils.py +19 -0
  108. {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/METADATA +5 -7
  109. {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/RECORD +113 -112
  110. {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/WHEEL +1 -1
  111. {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/entry_points.txt +0 -1
  112. diffusers/models/cross_attention.py +0 -94
  113. {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/LICENSE +0 -0
  114. {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/top_level.txt +0 -0
diffusers/__init__.py CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.19.3"
+__version__ = "0.20.1"
 
 from .configuration_utils import ConfigMixin
 from .utils import (
@@ -38,6 +38,7 @@ else:
     from .models import (
         AsymmetricAutoencoderKL,
         AutoencoderKL,
+        AutoencoderTiny,
         ControlNetModel,
         ModelMixin,
         MultiAdapter,
@@ -170,6 +171,7 @@ else:
         StableDiffusionControlNetPipeline,
         StableDiffusionDepth2ImgPipeline,
         StableDiffusionDiffEditPipeline,
+        StableDiffusionGLIGENPipeline,
         StableDiffusionImageVariationPipeline,
         StableDiffusionImg2ImgPipeline,
         StableDiffusionInpaintPipeline,
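The two headline additions of this release are visible right here in the top-level exports: `AutoencoderTiny` (a tiny distilled VAE) and `StableDiffusionGLIGENPipeline`. A minimal sketch of swapping the tiny VAE into an existing pipeline; the `madebyollin/taesd` checkpoint id and the base model are illustrative, not something this diff pins down:

```python
import torch
from diffusers import AutoencoderTiny, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
# AutoencoderTiny is new in 0.20: a distilled VAE that decodes much faster
# at some cost in fidelity. Checkpoint id assumed; substitute your own.
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

image = pipe("a photo of an astronaut riding a horse").images[0]
```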
diffusers/commands/fp16_safetensors.py CHANGED
@@ -27,7 +27,7 @@ import torch
 from huggingface_hub import hf_hub_download
 from packaging import version
 
-from ..utils import is_safetensors_available, logging
+from ..utils import logging
 from . import BaseDiffusersCLICommand
 
 
@@ -68,12 +68,7 @@ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
         self.local_ckpt_dir = f"/tmp/{ckpt_id}"
         self.fp16 = fp16
 
-        if is_safetensors_available():
-            self.use_safetensors = use_safetensors
-        else:
-            raise ImportError(
-                "When `use_safetensors` is set to True, the `safetensors` library needs to be installed. Install it via `pip install safetensors`."
-            )
+        self.use_safetensors = use_safetensors
 
         if not self.use_safetensors and not self.fp16:
             raise NotImplementedError(
diffusers/configuration_utils.py CHANGED
@@ -26,7 +26,7 @@ from pathlib import PosixPath
 from typing import Any, Dict, Tuple, Union
 
 import numpy as np
-from huggingface_hub import hf_hub_download
+from huggingface_hub import create_repo, hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError
 
@@ -144,6 +144,12 @@ class ConfigMixin:
         Args:
             save_directory (`str` or `os.PathLike`):
                 Directory where the configuration JSON file is saved (will be created if it does not exist).
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
+            kwargs (`Dict[str, Any]`, *optional*):
+                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
         if os.path.isfile(save_directory):
             raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
@@ -156,6 +162,22 @@ class ConfigMixin:
         self.to_json_file(output_config_file)
         logger.info(f"Configuration saved in {output_config_file}")
 
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            private = kwargs.pop("private", False)
+            create_pr = kwargs.pop("create_pr", False)
+            token = kwargs.pop("token", None)
+            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
+
+            self._upload_folder(
+                save_directory,
+                repo_id,
+                token=token,
+                commit_message=commit_message,
+                create_pr=create_pr,
+            )
+
     @classmethod
     def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
         r"""
diffusers/dependency_versions_table.py CHANGED
@@ -29,7 +29,7 @@ deps = {
     "pytest": "pytest",
     "pytest-timeout": "pytest-timeout",
     "pytest-xdist": "pytest-xdist",
-    "ruff": "ruff>=0.0.241",
+    "ruff": "ruff==0.0.280",
     "safetensors": "safetensors>=0.3.1",
     "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
     "scipy": "scipy",
diffusers/loaders.py CHANGED
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import os
 import re
 import warnings
@@ -21,6 +22,7 @@ from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union
 
 import requests
+import safetensors
 import torch
 import torch.nn.functional as F
 from huggingface_hub import hf_hub_download
@@ -33,16 +35,12 @@ from .utils import (
     deprecate,
     is_accelerate_available,
     is_omegaconf_available,
-    is_safetensors_available,
     is_transformers_available,
     logging,
 )
 from .utils.import_utils import BACKENDS_MAPPING
 
 
-if is_safetensors_available():
-    import safetensors
-
 if is_transformers_available():
     from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel, PreTrainedTokenizer
 
@@ -189,7 +187,7 @@ class UNet2DConditionLoadersMixin:
         r"""
         Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
         defined in
-        [`cross_attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py)
+        [`attention_processor.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py)
         and be a `torch.nn.Module` class.
 
         Parameters:
@@ -258,15 +256,12 @@ class UNet2DConditionLoadersMixin:
         # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
         # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
         network_alphas = kwargs.pop("network_alphas", None)
-
-        if use_safetensors and not is_safetensors_available():
-            raise ValueError(
-                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
-            )
+        is_network_alphas_none = network_alphas is None
 
         allow_pickle = False
+
         if use_safetensors is None:
-            use_safetensors = is_safetensors_available()
+            use_safetensors = True
             allow_pickle = True
 
         user_agent = {
@@ -349,13 +344,20 @@ class UNet2DConditionLoadersMixin:
 
                 # Create another `mapped_network_alphas` dictionary so that we can properly map them.
                 if network_alphas is not None:
-                    for k in network_alphas:
+                    network_alphas_ = copy.deepcopy(network_alphas)
+                    for k in network_alphas_:
                         if k.replace(".alpha", "") in key:
-                            mapped_network_alphas.update({attn_processor_key: network_alphas[k]})
+                            mapped_network_alphas.update({attn_processor_key: network_alphas.pop(k)})
+
+        if not is_network_alphas_none:
+            if len(network_alphas) > 0:
+                raise ValueError(
+                    f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
+                )
 
         if len(state_dict) > 0:
             raise ValueError(
-                f"The state_dict has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
+                f"The `state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
             )
 
         for key, value_dict in lora_grouped_dict.items():
@@ -434,14 +436,6 @@ class UNet2DConditionLoadersMixin:
                     v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight"),
                     out_rank=rank_mapping.get("to_out_lora.down.weight"),
                     out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight"),
-                    # rank=rank_mapping.get("to_k_lora.down.weight", None),
-                    # hidden_size=hidden_size_mapping.get("to_k_lora.up.weight", None),
-                    # q_rank=rank_mapping.get("to_q_lora.down.weight", None),
-                    # q_hidden_size=hidden_size_mapping.get("to_q_lora.up.weight", None),
-                    # v_rank=rank_mapping.get("to_v_lora.down.weight", None),
-                    # v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight", None),
-                    # out_rank=rank_mapping.get("to_out_lora.down.weight", None),
-                    # out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight", None),
                 )
             else:
                 attn_processors[key] = attn_processor_class(
@@ -496,9 +490,6 @@ class UNet2DConditionLoadersMixin:
         # set ff layers
         for target_module, lora_layer in non_attn_lora_layers:
             target_module.set_lora_layer(lora_layer)
-            # It should raise an error if we don't have a set lora here
-            # if hasattr(target_module, "set_lora_layer"):
-            #     target_module.set_lora_layer(lora_layer)
 
     def save_attn_procs(
         self,
@@ -506,7 +497,7 @@ class UNet2DConditionLoadersMixin:
         is_main_process: bool = True,
         weight_name: str = None,
         save_function: Callable = None,
-        safe_serialization: bool = False,
+        safe_serialization: bool = True,
         **kwargs,
     ):
         r"""
@@ -524,19 +515,14 @@ class UNet2DConditionLoadersMixin:
                 The function to use to save the state dictionary. Useful during distributed training when you need to
                 replace `torch.save` with another method. Can be configured with the environment variable
                 `DIFFUSERS_SAVE_MODE`.
-
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
         from .models.attention_processor import (
             CustomDiffusionAttnProcessor,
             CustomDiffusionXFormersAttnProcessor,
         )
 
-        weight_name = weight_name or deprecate(
-            "weights_name",
-            "0.20.0",
-            "`weights_name` is deprecated, please use `weight_name` instead.",
-            take_from=kwargs,
-        )
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
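The flipped `safe_serialization` default means `save_attn_procs` now writes a `.safetensors` file unless told otherwise. A sketch of the round trip; the LoRA repo is an example, not something this diff references:

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Example LoRA repo trained with diffusers; any attn-procs checkpoint works.
unet.load_attn_procs("sayakpaul/sd-model-finetuned-lora-t4")

# Now writes pytorch_lora_weights.safetensors by default;
# pass safe_serialization=False to recover the old pickled .bin.
unet.save_attn_procs("sd15-lora")
```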
@@ -766,14 +752,9 @@ class TextualInversionLoaderMixin:
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
 
-        if use_safetensors and not is_safetensors_available():
-            raise ValueError(
-                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
-            )
-
         allow_pickle = False
         if use_safetensors is None:
-            use_safetensors = is_safetensors_available()
+            use_safetensors = True
             allow_pickle = True
 
         user_agent = {
@@ -1023,14 +1004,9 @@ class LoraLoaderMixin:
         unet_config = kwargs.pop("unet_config", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
 
-        if use_safetensors and not is_safetensors_available():
-            raise ValueError(
-                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
-            )
-
         allow_pickle = False
         if use_safetensors is None:
-            use_safetensors = is_safetensors_available()
+            use_safetensors = True
             allow_pickle = True
 
         user_agent = {
@@ -1063,6 +1039,7 @@ class LoraLoaderMixin:
                 if not allow_pickle:
                     raise e
                 # try loading non-safetensors weights
+                model_file = None
                 pass
         if model_file is None:
             model_file = _get_model_file(
@@ -1258,9 +1235,10 @@ class LoraLoaderMixin:
         keys = list(state_dict.keys())
         prefix = cls.text_encoder_name if prefix is None else prefix
 
+        # Safe prefix to check with.
         if any(cls.text_encoder_name in key for key in keys):
             # Load the layers corresponding to text encoder and make necessary adjustments.
-            text_encoder_keys = [k for k in keys if k.startswith(prefix)]
+            text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
             text_encoder_lora_state_dict = {
                 k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
             }
@@ -1310,6 +1288,14 @@ class LoraLoaderMixin:
             ].shape[1]
             patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
 
+            if network_alphas is not None:
+                alpha_keys = [
+                    k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
+                ]
+                network_alphas = {
+                    k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+                }
+
             cls._modify_text_encoder(
                 text_encoder,
                 lora_scale,
@@ -1371,12 +1357,13 @@ class LoraLoaderMixin:
 
         lora_parameters = []
         network_alphas = {} if network_alphas is None else network_alphas
+        is_network_alphas_populated = len(network_alphas) > 0
 
         for name, attn_module in text_encoder_attn_modules(text_encoder):
-            query_alpha = network_alphas.get(name + ".k.proj.alpha")
-            key_alpha = network_alphas.get(name + ".q.proj.alpha")
-            value_alpha = network_alphas.get(name + ".v.proj.alpha")
-            proj_alpha = network_alphas.get(name + ".out.proj.alpha")
+            query_alpha = network_alphas.pop(name + ".to_q_lora.down.weight.alpha", None)
+            key_alpha = network_alphas.pop(name + ".to_k_lora.down.weight.alpha", None)
+            value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None)
+            out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None)
 
             attn_module.q_proj = PatchedLoraProjection(
                 attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=rank, dtype=dtype
@@ -1394,14 +1381,14 @@ class LoraLoaderMixin:
             lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())
 
             attn_module.out_proj = PatchedLoraProjection(
-                attn_module.out_proj, lora_scale, network_alpha=proj_alpha, rank=rank, dtype=dtype
+                attn_module.out_proj, lora_scale, network_alpha=out_alpha, rank=rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
 
         if patch_mlp:
             for name, mlp_module in text_encoder_mlp_modules(text_encoder):
-                fc1_alpha = network_alphas.get(name + ".fc1.alpha")
-                fc2_alpha = network_alphas.get(name + ".fc2.alpha")
+                fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha")
+                fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha")
 
                 mlp_module.fc1 = PatchedLoraProjection(
                     mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=rank, dtype=dtype
@@ -1413,6 +1400,11 @@ class LoraLoaderMixin:
                 )
                 lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
 
+        if is_network_alphas_populated and len(network_alphas) > 0:
+            raise ValueError(
+                f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
+            )
+
         return lora_parameters
 
     @classmethod
@@ -1424,7 +1416,7 @@ class LoraLoaderMixin:
         is_main_process: bool = True,
         weight_name: str = None,
         save_function: Callable = None,
-        safe_serialization: bool = False,
+        safe_serialization: bool = True,
     ):
         r"""
         Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -1445,6 +1437,8 @@ class LoraLoaderMixin:
                 The function to use to save the state dictionary. Useful during distributed training when you need to
                 replace `torch.save` with another method. Can be configured with the environment variable
                 `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
         # Create a flat dictionary.
         state_dict = {}
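A subtle fix in this span: when the safetensors lookup raises and pickle is allowed, `model_file` is now reset to `None`, so the `.bin` fallback in `load_lora_weights` actually fires. That is what keeps repos that only publish pickled LoRA weights loadable under the new safetensors-first default; the repo below is just an example:

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Example repo with only a pytorch_lora_weights.bin: use_safetensors defaults
# to True with allow_pickle=True, the safetensors fetch fails, model_file is
# reset to None, and the loader falls back to the pickled weights.
pipe.load_lora_weights("sayakpaul/sd-model-finetuned-lora-t4")
```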
@@ -1526,10 +1520,6 @@ class LoraLoaderMixin:
             lora_name_up = lora_name + ".lora_up.weight"
             lora_name_alpha = lora_name + ".alpha"
 
-            # if lora_name_alpha in state_dict:
-            #     alpha = state_dict.pop(lora_name_alpha).item()
-            #     network_alphas.update({lora_name_alpha: alpha})
-
             if lora_name.startswith("lora_unet_"):
                 diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
 
@@ -1851,7 +1841,7 @@ class FromSingleFileMixin:
 
         torch_dtype = kwargs.pop("torch_dtype", None)
 
-        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+        use_safetensors = kwargs.pop("use_safetensors", None)
 
         pipeline_name = cls.__name__
         file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
@@ -1892,16 +1882,24 @@ class FromSingleFileMixin:
             raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
 
         # remove huggingface url
-        for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
+        has_valid_url_prefix = False
+        valid_url_prefixes = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
+        for prefix in valid_url_prefixes:
             if pretrained_model_link_or_path.startswith(prefix):
                 pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
+                has_valid_url_prefix = True
 
         # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
         ckpt_path = Path(pretrained_model_link_or_path)
         if not ckpt_path.is_file():
+            if not has_valid_url_prefix:
+                raise ValueError(
+                    f"The provided path is either not a file or a valid huggingface URL was not provided. Valid URLs begin with {', '.join(valid_url_prefixes)}"
+                )
+
             # get repo_id and (potentially nested) file path of ckpt in repo
-            repo_id = os.path.join(*ckpt_path.parts[:2])
-            file_path = os.path.join(*ckpt_path.parts[2:])
+            repo_id = "/".join(ckpt_path.parts[:2])
+            file_path = "/".join(ckpt_path.parts[2:])
 
             if file_path.startswith("blob/"):
                 file_path = file_path[len("blob/") :]
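With the stricter URL handling, `from_single_file` now fails fast with a clear error instead of a confusing downstream failure when the argument is neither a local file nor a Hub URL. The accepted prefixes are exactly the four listed above; for example:

```python
from diffusers import StableDiffusionPipeline

# Full https://huggingface.co/... URLs (including /blob/ links), bare
# hf.co/... forms, and local .ckpt/.safetensors paths are accepted.
pipe = StableDiffusionPipeline.from_single_file(
    "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
)
```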
@@ -2048,7 +2046,7 @@ class FromOriginalVAEMixin:
 
         torch_dtype = kwargs.pop("torch_dtype", None)
 
-        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+        use_safetensors = kwargs.pop("use_safetensors", None)
 
         file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
         from_safetensors = file_extension == "safetensors"
@@ -2221,7 +2219,7 @@ class FromOriginalControlnetMixin:
 
         torch_dtype = kwargs.pop("torch_dtype", None)
 
-        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+        use_safetensors = kwargs.pop("use_safetensors", None)
 
         file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
         from_safetensors = file_extension == "safetensors"
diffusers/models/__init__.py CHANGED
@@ -19,6 +19,7 @@ if is_torch_available():
     from .adapter import MultiAdapter, T2IAdapter
     from .autoencoder_asym_kl import AsymmetricAutoencoderKL
     from .autoencoder_kl import AutoencoderKL
+    from .autoencoder_tiny import AutoencoderTiny
     from .controlnet import ControlNetModel
     from .dual_transformer_2d import DualTransformer2DModel
     from .modeling_utils import ModelMixin
diffusers/models/activations.py CHANGED
@@ -8,5 +8,7 @@ def get_activation(act_fn):
         return nn.Mish()
     elif act_fn == "gelu":
         return nn.GELU()
+    elif act_fn == "relu":
+        return nn.ReLU()
     else:
         raise ValueError(f"Unsupported activation function: {act_fn}")
diffusers/models/attention.py CHANGED
@@ -24,6 +24,38 @@ from .embeddings import CombinedTimestepLabelEmbeddings
 from .lora import LoRACompatibleLinear
 
 
+@maybe_allow_in_graph
+class GatedSelfAttentionDense(nn.Module):
+    def __init__(self, query_dim, context_dim, n_heads, d_head):
+        super().__init__()
+
+        # we need a linear projection since we need cat visual feature and obj feature
+        self.linear = nn.Linear(context_dim, query_dim)
+
+        self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
+        self.ff = FeedForward(query_dim, activation_fn="geglu")
+
+        self.norm1 = nn.LayerNorm(query_dim)
+        self.norm2 = nn.LayerNorm(query_dim)
+
+        self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
+        self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
+
+        self.enabled = True
+
+    def forward(self, x, objs):
+        if not self.enabled:
+            return x
+
+        n_visual = x.shape[1]
+        objs = self.linear(objs)
+
+        x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
+        x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
+
+        return x
+
+
 @maybe_allow_in_graph
 class BasicTransformerBlock(nn.Module):
     r"""
@@ -62,6 +94,7 @@ class BasicTransformerBlock(nn.Module):
         norm_elementwise_affine: bool = True,
         norm_type: str = "layer_norm",
         final_dropout: bool = False,
+        attention_type: str = "default",
     ):
         super().__init__()
         self.only_cross_attention = only_cross_attention
@@ -120,6 +153,10 @@ class BasicTransformerBlock(nn.Module):
         self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
 
+        # 4. Fuser
+        if attention_type == "gated":
+            self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
+
         # let chunk size default to None
         self._chunk_size = None
         self._chunk_dim = 0
@@ -150,7 +187,9 @@ class BasicTransformerBlock(nn.Module):
         else:
             norm_hidden_states = self.norm1(hidden_states)
 
-        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+        # 0. Prepare GLIGEN inputs
+        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
 
         attn_output = self.attn1(
             norm_hidden_states,
@@ -162,6 +201,11 @@ class BasicTransformerBlock(nn.Module):
             attn_output = gate_msa.unsqueeze(1) * attn_output
             hidden_states = attn_output + hidden_states
 
+        # 1.5 GLIGEN Control
+        if gligen_kwargs is not None:
+            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+        # 1.5 ends
+
         # 2. Cross-Attention
         if self.attn2 is not None:
             norm_hidden_states = (
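At the user level, these `gligen` kwargs are populated by the new `StableDiffusionGLIGENPipeline`, which routes phrase embeddings and normalized boxes into `cross_attention_kwargs["gligen"]` for the fuser call above. A sketch of the call; the checkpoint id and parameter names follow the GLIGEN release and are not spelled out in this diff:

```python
import torch
from diffusers import StableDiffusionGLIGENPipeline

pipe = StableDiffusionGLIGENPipeline.from_pretrained(
    "masterful/gligen-1-4-generation-text-box", torch_dtype=torch.float16  # assumed checkpoint
).to("cuda")

# One normalized (x0, y0, x1, y1) box per grounding phrase.
image = pipe(
    prompt="a birthday cake on a wooden table",
    gligen_phrases=["a birthday cake"],
    gligen_boxes=[[0.25, 0.4, 0.75, 0.9]],
    gligen_scheduled_sampling_beta=1.0,
    num_inference_steps=50,
).images[0]
```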