diffusers 0.28.2__py3-none-any.whl → 0.29.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +9 -1
- diffusers/commands/env.py +1 -5
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +2 -1
- diffusers/loaders/__init__.py +2 -2
- diffusers/loaders/lora.py +406 -140
- diffusers/loaders/lora_conversion_utils.py +7 -1
- diffusers/loaders/single_file.py +1 -1
- diffusers/loaders/single_file_model.py +5 -0
- diffusers/loaders/single_file_utils.py +242 -2
- diffusers/loaders/unet.py +307 -272
- diffusers/models/__init__.py +5 -3
- diffusers/models/attention.py +125 -1
- diffusers/models/attention_processor.py +169 -1
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl.py +17 -6
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -2
- diffusers/models/autoencoders/consistency_decoder_vae.py +9 -9
- diffusers/models/autoencoders/vq_model.py +182 -0
- diffusers/models/controlnet_xs.py +6 -6
- diffusers/models/embeddings.py +112 -84
- diffusers/models/model_loading_utils.py +55 -0
- diffusers/models/modeling_utils.py +128 -17
- diffusers/models/normalization.py +11 -6
- diffusers/models/transformers/__init__.py +1 -0
- diffusers/models/transformers/dual_transformer_2d.py +5 -4
- diffusers/models/transformers/hunyuan_transformer_2d.py +149 -2
- diffusers/models/transformers/prior_transformer.py +5 -5
- diffusers/models/transformers/transformer_2d.py +2 -2
- diffusers/models/transformers/transformer_sd3.py +344 -0
- diffusers/models/transformers/transformer_temporal.py +12 -10
- diffusers/models/unets/unet_1d.py +3 -3
- diffusers/models/unets/unet_2d.py +3 -3
- diffusers/models/unets/unet_2d_condition.py +4 -15
- diffusers/models/unets/unet_3d_condition.py +5 -17
- diffusers/models/unets/unet_i2vgen_xl.py +4 -4
- diffusers/models/unets/unet_motion_model.py +4 -4
- diffusers/models/unets/unet_spatio_temporal_condition.py +3 -3
- diffusers/models/vq_model.py +8 -165
- diffusers/pipelines/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +4 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +4 -3
- diffusers/pipelines/controlnet/pipeline_controlnet.py +4 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +4 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +4 -3
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +4 -3
- diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +4 -3
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +24 -5
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +4 -3
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +4 -3
- diffusers/pipelines/marigold/marigold_image_processing.py +35 -20
- diffusers/pipelines/pia/pipeline_pia.py +4 -3
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +17 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -6
- diffusers/pipelines/stable_diffusion_3/__init__.py +52 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +886 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +923 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +4 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +10 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +4 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +4 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +4 -3
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +4 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +4 -3
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +4 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +4 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +4 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +4 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +4 -3
- diffusers/schedulers/__init__.py +2 -0
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -3
- diffusers/schedulers/scheduling_edm_euler.py +2 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +287 -0
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/training_utils.py +4 -4
- diffusers/utils/__init__.py +3 -0
- diffusers/utils/constants.py +2 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +30 -0
- diffusers/utils/dynamic_modules_utils.py +15 -13
- diffusers/utils/hub_utils.py +106 -0
- diffusers/utils/import_utils.py +0 -1
- diffusers/utils/logging.py +3 -1
- diffusers/utils/state_dict_utils.py +2 -0
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/METADATA +45 -45
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/RECORD +108 -111
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/WHEEL +1 -1
- diffusers/models/dual_transformer_2d.py +0 -20
- diffusers/models/prior_transformer.py +0 -12
- diffusers/models/t5_film_transformer.py +0 -70
- diffusers/models/transformer_2d.py +0 -25
- diffusers/models/transformer_temporal.py +0 -34
- diffusers/models/unet_1d.py +0 -26
- diffusers/models/unet_1d_blocks.py +0 -203
- diffusers/models/unet_2d.py +0 -27
- diffusers/models/unet_2d_blocks.py +0 -375
- diffusers/models/unet_2d_condition.py +0 -25
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/LICENSE +0 -0
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/top_level.txt +0 -0
diffusers/models/modeling_utils.py

@@ -16,6 +16,7 @@
 
 import inspect
 import itertools
+import json
 import os
 import re
 from collections import OrderedDict
@@ -25,7 +26,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union
 
 import safetensors
 import torch
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, split_torch_state_dict_into_shards
 from huggingface_hub.utils import validate_hf_hub_args
 from torch import Tensor, nn
 
@@ -33,9 +34,12 @@ from .. import __version__
 from ..utils import (
     CONFIG_NAME,
     FLAX_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     _add_variant,
+    _get_checkpoint_shard_files,
     _get_model_file,
     deprecate,
     is_accelerate_available,
@@ -49,6 +53,7 @@ from ..utils.hub_utils import (
 )
 from .model_loading_utils import (
     _determine_device_map,
+    _fetch_index_file,
     _load_state_dict_into_model,
     load_model_dict_into_meta,
     load_state_dict,
@@ -57,6 +62,8 @@ from .model_loading_utils import (
 
 logger = logging.get_logger(__name__)
 
+_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")
+
 
 if is_torch_version(">=", "1.9.0"):
     _LOW_CPU_MEM_USAGE_DEFAULT = True
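The new module-level `_REGEX_SHARD` pattern is what later identifies shard files during cleanup. A minimal standalone sketch of how it behaves (filenames are illustrative; in `save_pretrained` the extension is stripped before matching):

```python
import re

_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")

# Shard files carry a zero-padded "-NNNNN-of-NNNNN" suffix (extension already stripped).
print(_REGEX_SHARD.fullmatch("diffusion_pytorch_model-00001-of-00003"))  # <re.Match ...>
print(_REGEX_SHARD.fullmatch("diffusion_pytorch_model"))                 # None: not a shard
```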
diffusers/models/modeling_utils.py

@@ -263,6 +270,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         save_function: Optional[Callable] = None,
         safe_serialization: bool = True,
         variant: Optional[str] = None,
+        max_shard_size: Union[int, str] = "10GB",
         push_to_hub: bool = False,
         **kwargs,
     ):
@@ -285,6 +293,13 @@
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             variant (`str`, *optional*):
                 If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
+            max_shard_size (`int` or `str`, defaults to `"10GB"`):
+                The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
+                lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
+                If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
+                period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`.
+                This is to establish a common default size for this argument across different libraries in the Hugging
+                Face ecosystem (`transformers`, and `accelerate`, for example).
             push_to_hub (`bool`, *optional*, defaults to `False`):
                 Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
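`max_shard_size` accepts a string with a unit or an integer byte count. A hedged usage sketch of the new argument (the checkpoint id and shard limit are illustrative, not taken from the diff):

```python
from diffusers import UNet2DConditionModel

# Any ModelMixin subclass picks up the new argument. "1GB" forces sharding for
# a multi-GB checkpoint; the integer 1024**3 would express the same limit in bytes.
unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="unet")
unet.save_pretrained("./unet-sharded", max_shard_size="1GB")
```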
diffusers/models/modeling_utils.py

@@ -296,6 +311,14 @@
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
+        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+        weights_name = _add_variant(weights_name, variant)
+        weight_name_split = weights_name.split(".")
+        if len(weight_name_split) in [2, 3]:
+            weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
+        else:
+            raise ValueError(f"Invalid {weights_name} provided.")
+
         os.makedirs(save_directory, exist_ok=True)
 
         if push_to_hub:
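The `{suffix}` placeholder is where `split_torch_state_dict_into_shards` injects the shard index. A standalone sketch of the string logic above (the variant name is illustrative):

```python
weights_name = "diffusion_pytorch_model.fp16.safetensors"  # variant="fp16" -> 3 dot-separated parts
parts = weights_name.split(".")
weights_name_pattern = parts[0] + "{suffix}." + ".".join(parts[1:])

print(weights_name_pattern)
# diffusion_pytorch_model{suffix}.fp16.safetensors
print(weights_name_pattern.format(suffix="-00001-of-00002"))
# diffusion_pytorch_model-00001-of-00002.fp16.safetensors
```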
diffusers/models/modeling_utils.py

@@ -317,18 +340,58 @@
         # Save the model
         state_dict = model_to_save.state_dict()
 
-        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
-        weights_name = _add_variant(weights_name, variant)
-
         # Save the model
-        if safe_serialization:
-            safetensors.torch.save_file(
-                state_dict, Path(save_directory, weights_name).as_posix(), metadata={"format": "pt"}
+        state_dict_split = split_torch_state_dict_into_shards(
+            state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
+        )
+
+        # Clean the folder from a previous save
+        if is_main_process:
+            for filename in os.listdir(save_directory):
+                if filename in state_dict_split.filename_to_tensors.keys():
+                    continue
+                full_filename = os.path.join(save_directory, filename)
+                if not os.path.isfile(full_filename):
+                    continue
+                weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
+                weights_without_ext = weights_without_ext.replace("{suffix}", "")
+                filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
+                # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
+                if (
+                    filename.startswith(weights_without_ext)
+                    and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
+                ):
+                    os.remove(full_filename)
+
+        for filename, tensors in state_dict_split.filename_to_tensors.items():
+            shard = {tensor: state_dict[tensor] for tensor in tensors}
+            filepath = os.path.join(save_directory, filename)
+            if safe_serialization:
+                # At some point we will need to deal better with save_function (used for TPU and other distributed
+                # joyfulness), but for now this enough.
+                safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
+            else:
+                torch.save(shard, filepath)
+
+        if state_dict_split.is_sharded:
+            index = {
+                "metadata": state_dict_split.metadata,
+                "weight_map": state_dict_split.tensor_to_filename,
+            }
+            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+            save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
+            # Save the index as well
+            with open(save_index_file, "w", encoding="utf-8") as f:
+                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+                f.write(content)
+            logger.info(
+                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+                f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
+                f"index located at {save_index_file}."
             )
         else:
-            torch.save(state_dict, Path(save_directory, weights_name).as_posix())
-
-        logger.info(f"Model weights saved in {Path(save_directory, weights_name).as_posix()}")
+            path_to_weights = os.path.join(save_directory, weights_name)
+            logger.info(f"Model weights saved in {path_to_weights}")
 
         if push_to_hub:
            # Create a new empty model card and eventually tag it
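When sharding kicks in, the save directory ends up with numbered shard files plus a JSON index. An illustrative sketch of the index layout (tensor names and sizes are made up):

```python
import json

index = {
    "metadata": {"total_size": 3438167552},  # bytes across all shards
    "weight_map": {
        # every parameter name maps to the shard file that stores it
        "conv_in.bias": "diffusion_pytorch_model-00001-of-00004.safetensors",
        "conv_in.weight": "diffusion_pytorch_model-00001-of-00004.safetensors",
        "down_blocks.0.attentions.0.norm.bias": "diffusion_pytorch_model-00002-of-00004.safetensors",
    },
}
print(json.dumps(index, indent=2, sort_keys=True))
```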
diffusers/models/modeling_utils.py

@@ -566,6 +629,32 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             **kwargs,
         )
 
+        # Determine if we're loading from a directory of sharded checkpoints.
+        is_sharded = False
+        index_file = None
+        is_local = os.path.isdir(pretrained_model_name_or_path)
+        index_file = _fetch_index_file(
+            is_local=is_local,
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            subfolder=subfolder or "",
+            use_safetensors=use_safetensors,
+            cache_dir=cache_dir,
+            variant=variant,
+            force_download=force_download,
+            resume_download=resume_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            user_agent=user_agent,
+            commit_hash=commit_hash,
+        )
+        if index_file is not None and index_file.is_file():
+            is_sharded = True
+
+        if is_sharded and from_flax:
+            raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")
+
         # load model
         model_file = None
         if from_flax:
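`_fetch_index_file` looks for the sharded-checkpoint index next to the weights. A local-directory sketch of the same check (the constant value mirrors the usual diffusers weight names and is stated here as an assumption):

```python
import os

# Assumed to match diffusers' constant for the default (variant-less) safetensors case.
SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json"

def looks_sharded(local_dir: str, subfolder: str = "") -> bool:
    # from_pretrained treats the presence of the index file as "this checkpoint is sharded".
    return os.path.isfile(os.path.join(local_dir, subfolder, SAFE_WEIGHTS_INDEX_NAME))

print(looks_sharded("./unet-sharded"))
```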
diffusers/models/modeling_utils.py

@@ -590,7 +679,21 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
 
             model = load_flax_checkpoint_in_pytorch_model(model, model_file)
         else:
-            if use_safetensors:
+            if is_sharded:
+                sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
+                    pretrained_model_name_or_path,
+                    index_file,
+                    cache_dir=cache_dir,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    token=token,
+                    user_agent=user_agent,
+                    revision=revision,
+                    subfolder=subfolder or "",
+                )
+
+            elif use_safetensors and not is_sharded:
                 try:
                     model_file = _get_model_file(
                         pretrained_model_name_or_path,
@@ -606,11 +709,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                         user_agent=user_agent,
                         commit_hash=commit_hash,
                     )
+
                 except IOError as e:
+                    logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
                     if not allow_pickle:
-                        raise
-
-            if model_file is None:
+                        raise
+                    logger.warning(
+                        "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
+                    )
+
+            if model_file is None and not is_sharded:
                 model_file = _get_model_file(
                     pretrained_model_name_or_path,
                     weights_name=_add_variant(WEIGHTS_NAME, variant),
@@ -632,7 +740,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         model = cls.from_config(config, **unused_kwargs)
 
         # if device_map is None, load the state dict and move the params from meta device to the cpu
-        if device_map is None:
+        if device_map is None and not is_sharded:
             param_device = "cpu"
             state_dict = load_state_dict(model_file, variant=variant)
             model._convert_deprecated_attention_blocks(state_dict)
@@ -670,7 +778,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 try:
                     accelerate.load_checkpoint_and_dispatch(
                         model,
-                        model_file,
+                        model_file if not is_sharded else sharded_ckpt_cached_folder,
                         device_map,
                         max_memory=max_memory,
                         offload_folder=offload_folder,
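Loading is symmetric and needs no new arguments: sharding is detected from the index file, and with `accelerate` installed a `device_map` dispatches the shards as they load. A hedged sketch, reusing the folder saved above:

```python
from diffusers import UNet2DConditionModel

# The index file in ./unet-sharded marks the checkpoint as sharded.
unet = UNet2DConditionModel.from_pretrained("./unet-sharded")

# With accelerate installed, shards can be dispatched across devices while loading:
# unet = UNet2DConditionModel.from_pretrained("./unet-sharded", device_map="auto")
```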
diffusers/models/modeling_utils.py

@@ -1057,6 +1165,9 @@ class LegacyModelMixin(ModelMixin):
         # To prevent depedency import problem.
         from .model_loading_utils import _fetch_remapped_cls_from_config
 
+        # Create a copy of the kwargs so that we don't mess with the keyword arguments in the downstream calls.
+        kwargs_copy = kwargs.copy()
+
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         resume_download = kwargs.pop("resume_download", None)
@@ -1094,4 +1205,4 @@ class LegacyModelMixin(ModelMixin):
         # resolve remapping
         remapped_class = _fetch_remapped_cls_from_config(config, cls)
 
-        return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
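The `kwargs_copy` fix matters because the `kwargs.pop(...)` calls above mutate the dict in place, so delegating with the popped dict would silently drop arguments. In miniature:

```python
def downstream(**kwargs):
    return kwargs

def dispatch(**kwargs):
    kwargs_copy = kwargs.copy()
    kwargs.pop("cache_dir", None)     # consumed while resolving the config
    kwargs.pop("revision", None)
    return downstream(**kwargs_copy)  # downstream still sees everything

print(dispatch(cache_dir="/tmp/hub", revision="main", variant="fp16"))
# {'cache_dir': '/tmp/hub', 'revision': 'main', 'variant': 'fp16'}
```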
diffusers/models/normalization.py

@@ -57,10 +57,12 @@ class AdaLayerNormZero(nn.Module):
         num_embeddings (`int`): The size of the embeddings dictionary.
     """
 
-    def __init__(self, embedding_dim: int, num_embeddings: int):
+    def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None):
         super().__init__()
-
-        self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
+        if num_embeddings is not None:
+            self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
+        else:
+            self.emb = None
 
         self.silu = nn.SiLU()
         self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
@@ -69,11 +71,14 @@ class AdaLayerNormZero(nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        timestep: torch.Tensor,
-        class_labels: torch.LongTensor,
+        timestep: Optional[torch.Tensor] = None,
+        class_labels: Optional[torch.LongTensor] = None,
         hidden_dtype: Optional[torch.dtype] = None,
+        emb: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
+        if self.emb is not None:
+            emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
+        emb = self.linear(self.silu(emb))
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
         return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
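`AdaLayerNormZero` can now be driven by a precomputed conditioning embedding (the pattern SD3-style blocks use) instead of computing one from `timestep`/`class_labels`. A hedged sketch with illustrative shapes:

```python
import torch
from diffusers.models.normalization import AdaLayerNormZero

x = torch.randn(2, 16, 64)  # (batch, tokens, embedding_dim)

# Original path: the embedding is computed internally from timestep + class label.
norm = AdaLayerNormZero(embedding_dim=64, num_embeddings=10)
out, gate_msa, shift_mlp, scale_mlp, gate_mlp = norm(
    x, timestep=torch.tensor([1, 2]), class_labels=torch.tensor([0, 3]), hidden_dtype=x.dtype
)

# New path: num_embeddings=None leaves self.emb unset, so the caller supplies emb.
norm_sd3 = AdaLayerNormZero(embedding_dim=64)
out2, *gates = norm_sd3(x, emb=torch.randn(2, 64))
```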
diffusers/models/transformers/__init__.py

@@ -9,4 +9,5 @@ if is_torch_available():
     from .prior_transformer import PriorTransformer
     from .t5_film_transformer import T5FilmDecoder
     from .transformer_2d import Transformer2DModel
+    from .transformer_sd3 import SD3Transformer2DModel
     from .transformer_temporal import TransformerTemporalModel
diffusers/models/transformers/dual_transformer_2d.py

@@ -15,7 +15,8 @@ from typing import Optional
 
 from torch import nn
 
-from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
+from ..modeling_outputs import Transformer2DModelOutput
+from .transformer_2d import Transformer2DModel
 
 
 class DualTransformer2DModel(nn.Module):
@@ -123,9 +124,9 @@ class DualTransformer2DModel(nn.Module):
                 tuple.
 
         Returns:
-            [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
-            [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a
-            returning a tuple, the first element is the sample tensor.
+            [`~models.transformers.transformer_2d.Transformer2DModelOutput`] or `tuple`:
+            [`~models.transformers.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is the sample tensor.
         """
         input_states = hidden_states
 
diffusers/models/transformers/hunyuan_transformer_2d.py

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import Dict, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
-from ..attention_processor import Attention, HunyuanAttnProcessor2_0
+from ..attention_processor import Attention, AttentionProcessor, HunyuanAttnProcessor2_0
 from ..embeddings import (
     HunyuanCombinedTimestepTextSizeStyleEmbedding,
     PatchEmbed,
@@ -166,6 +166,7 @@ class HunyuanDiTBlock(nn.Module):
         self._chunk_size = None
         self._chunk_dim = 0
 
+    # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
     def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
         # Sets chunk feed-forward
         self._chunk_size = chunk_size
@@ -321,6 +322,110 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
         self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
 
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(HunyuanAttnProcessor2_0())
+
     def forward(
         self,
         hidden_states,
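These methods give `HunyuanDiT2DModel` the same attention-processor surface as the UNets. A hedged usage sketch (the checkpoint id is illustrative):

```python
import torch
from diffusers import HunyuanDiT2DModel

model = HunyuanDiT2DModel.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer", torch_dtype=torch.float16
)

print(len(model.attn_processors))  # one entry per attention layer

model.fuse_qkv_projections()       # experimental: fuse q/k/v (and k/v) projections
# ... run inference ...
model.unfuse_qkv_projections()     # restore the processors captured before fusing
```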
diffusers/models/transformers/hunyuan_transformer_2d.py

@@ -425,3 +530,45 @@
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
+
+    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+        """
+        Sets the attention processor to use [feed forward
+        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+        Parameters:
+            chunk_size (`int`, *optional*):
+                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+                over each tensor of dim=`dim`.
+            dim (`int`, *optional*, defaults to `0`):
+                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+                or dim=1 (sequence length).
+        """
+        if dim not in [0, 1]:
+            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+        # By default chunk size is 1
+        chunk_size = chunk_size or 1
+
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, chunk_size, dim)
+
+    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
+    def disable_forward_chunking(self):
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, None, 0)
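Feed-forward chunking trades a little speed for lower peak memory by running the transformer MLPs over slices of a chosen dimension. A hedged sketch, continuing the illustrative checkpoint above:

```python
import torch
from diffusers import HunyuanDiT2DModel

model = HunyuanDiT2DModel.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer", torch_dtype=torch.float16
)

model.enable_forward_chunking(chunk_size=1, dim=1)  # chunk over the sequence dimension
# ... run a memory-constrained forward pass ...
model.disable_forward_chunking()
```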
diffusers/models/transformers/prior_transformer.py

@@ -266,13 +266,13 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
             attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
                 Text mask for the text embeddings.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of
-                tuple.
+                Whether or not to return a [`~models.transformers.prior_transformer.PriorTransformerOutput`] instead of
+                a plain tuple.
 
         Returns:
-            [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
-            If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is
-            tuple is returned where the first element is the sample tensor.
+            [`~models.transformers.prior_transformer.PriorTransformerOutput`] or `tuple`:
+            If return_dict is True, a [`~models.transformers.prior_transformer.PriorTransformerOutput`] is
+            returned, otherwise a tuple is returned where the first element is the sample tensor.
         """
         batch_size = hidden_states.shape[0]
 
diffusers/models/transformers/transformer_2d.py

@@ -369,8 +369,8 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
                 tuple.
 
         Returns:
-            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned,
-            `tuple` where the first element is the sample tensor.
+            If `return_dict` is True, an [`~models.transformers.transformer_2d.Transformer2DModelOutput`] is returned,
+            otherwise a `tuple` where the first element is the sample tensor.
         """
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None: