diffusers 0.28.2__py3-none-any.whl → 0.29.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. diffusers/__init__.py +9 -1
  2. diffusers/commands/env.py +1 -5
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +2 -1
  5. diffusers/loaders/__init__.py +2 -2
  6. diffusers/loaders/lora.py +406 -140
  7. diffusers/loaders/lora_conversion_utils.py +7 -1
  8. diffusers/loaders/single_file.py +1 -1
  9. diffusers/loaders/single_file_model.py +5 -0
  10. diffusers/loaders/single_file_utils.py +242 -2
  11. diffusers/loaders/unet.py +307 -272
  12. diffusers/models/__init__.py +5 -3
  13. diffusers/models/attention.py +125 -1
  14. diffusers/models/attention_processor.py +169 -1
  15. diffusers/models/autoencoders/__init__.py +1 -0
  16. diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
  17. diffusers/models/autoencoders/autoencoder_kl.py +17 -6
  18. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -2
  19. diffusers/models/autoencoders/consistency_decoder_vae.py +9 -9
  20. diffusers/models/autoencoders/vq_model.py +182 -0
  21. diffusers/models/controlnet_xs.py +6 -6
  22. diffusers/models/embeddings.py +112 -84
  23. diffusers/models/model_loading_utils.py +55 -0
  24. diffusers/models/modeling_utils.py +128 -17
  25. diffusers/models/normalization.py +11 -6
  26. diffusers/models/transformers/__init__.py +1 -0
  27. diffusers/models/transformers/dual_transformer_2d.py +5 -4
  28. diffusers/models/transformers/hunyuan_transformer_2d.py +149 -2
  29. diffusers/models/transformers/prior_transformer.py +5 -5
  30. diffusers/models/transformers/transformer_2d.py +2 -2
  31. diffusers/models/transformers/transformer_sd3.py +344 -0
  32. diffusers/models/transformers/transformer_temporal.py +12 -10
  33. diffusers/models/unets/unet_1d.py +3 -3
  34. diffusers/models/unets/unet_2d.py +3 -3
  35. diffusers/models/unets/unet_2d_condition.py +4 -15
  36. diffusers/models/unets/unet_3d_condition.py +5 -17
  37. diffusers/models/unets/unet_i2vgen_xl.py +4 -4
  38. diffusers/models/unets/unet_motion_model.py +4 -4
  39. diffusers/models/unets/unet_spatio_temporal_condition.py +3 -3
  40. diffusers/models/vq_model.py +8 -165
  41. diffusers/pipelines/__init__.py +2 -0
  42. diffusers/pipelines/animatediff/pipeline_animatediff.py +4 -3
  43. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +4 -3
  44. diffusers/pipelines/controlnet/pipeline_controlnet.py +4 -3
  45. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +4 -3
  46. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +4 -3
  47. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +4 -3
  48. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  49. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +4 -3
  50. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +4 -3
  51. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +4 -3
  52. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +4 -3
  53. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +4 -3
  54. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +24 -5
  55. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +4 -3
  56. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +4 -3
  57. diffusers/pipelines/marigold/marigold_image_processing.py +35 -20
  58. diffusers/pipelines/pia/pipeline_pia.py +4 -3
  59. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  60. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  61. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +17 -17
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -3
  65. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -3
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +4 -3
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +4 -3
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -6
  69. diffusers/pipelines/stable_diffusion_3/__init__.py +52 -0
  70. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  71. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +886 -0
  72. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +923 -0
  73. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +4 -3
  74. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +10 -11
  75. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +4 -3
  76. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +4 -3
  77. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +4 -3
  78. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +4 -3
  79. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +4 -3
  80. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +4 -3
  81. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +4 -3
  82. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +4 -3
  83. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +4 -3
  84. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -3
  85. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  86. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +4 -3
  87. diffusers/schedulers/__init__.py +2 -0
  88. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  89. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -3
  90. diffusers/schedulers/scheduling_edm_euler.py +2 -4
  91. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +287 -0
  92. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  93. diffusers/training_utils.py +4 -4
  94. diffusers/utils/__init__.py +3 -0
  95. diffusers/utils/constants.py +2 -0
  96. diffusers/utils/dummy_pt_objects.py +30 -0
  97. diffusers/utils/dummy_torch_and_transformers_objects.py +30 -0
  98. diffusers/utils/dynamic_modules_utils.py +15 -13
  99. diffusers/utils/hub_utils.py +106 -0
  100. diffusers/utils/import_utils.py +0 -1
  101. diffusers/utils/logging.py +3 -1
  102. diffusers/utils/state_dict_utils.py +2 -0
  103. {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/METADATA +45 -45
  104. {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/RECORD +108 -111
  105. {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/WHEEL +1 -1
  106. diffusers/models/dual_transformer_2d.py +0 -20
  107. diffusers/models/prior_transformer.py +0 -12
  108. diffusers/models/t5_film_transformer.py +0 -70
  109. diffusers/models/transformer_2d.py +0 -25
  110. diffusers/models/transformer_temporal.py +0 -34
  111. diffusers/models/unet_1d.py +0 -26
  112. diffusers/models/unet_1d_blocks.py +0 -203
  113. diffusers/models/unet_2d.py +0 -27
  114. diffusers/models/unet_2d_blocks.py +0 -375
  115. diffusers/models/unet_2d_condition.py +0 -25
  116. {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/LICENSE +0 -0
  117. {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/entry_points.txt +0 -0
  118. {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/top_level.txt +0 -0
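Two changes dominate this release, per the list above: initial Stable Diffusion 3 support (the new diffusers/models/transformers/transformer_sd3.py, diffusers/schedulers/scheduling_flow_match_euler_discrete.py, and diffusers/pipelines/stable_diffusion_3/ files) and sharded checkpoint save/load in diffusers/models/modeling_utils.py. A minimal sketch of the new pipeline, assuming access to the gated stabilityai/stable-diffusion-3-medium-diffusers checkpoint on the Hub:

    import torch
    from diffusers import StableDiffusion3Pipeline

    pipe = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
    ).to("cuda")
    image = pipe(
        "a photo of a cat holding a sign that says hello world",
        num_inference_steps=28,
        guidance_scale=7.0,
    ).images[0]
    image.save("sd3.png")

The remaining hunks below walk through the sharding machinery and the smaller API ports in detail.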
diffusers/models/modeling_utils.py

@@ -16,6 +16,7 @@
 
 import inspect
 import itertools
+import json
 import os
 import re
 from collections import OrderedDict
@@ -25,7 +26,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union
 
 import safetensors
 import torch
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, split_torch_state_dict_into_shards
 from huggingface_hub.utils import validate_hf_hub_args
 from torch import Tensor, nn
 
@@ -33,9 +34,12 @@ from .. import __version__
 from ..utils import (
     CONFIG_NAME,
     FLAX_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     _add_variant,
+    _get_checkpoint_shard_files,
     _get_model_file,
     deprecate,
     is_accelerate_available,
@@ -49,6 +53,7 @@ from ..utils.hub_utils import (
 )
 from .model_loading_utils import (
     _determine_device_map,
+    _fetch_index_file,
     _load_state_dict_into_model,
     load_model_dict_into_meta,
     load_state_dict,
@@ -57,6 +62,8 @@ from .model_loading_utils import (
 
 logger = logging.get_logger(__name__)
 
+_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")
+
 
 if is_torch_version(">=", "1.9.0"):
     _LOW_CPU_MEM_USAGE_DEFAULT = True
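For reference, the new _REGEX_SHARD pattern recognizes the transformers-style shard numbering; a quick standalone sketch of what it does and does not match:

    import re

    _REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")

    # Used by save_pretrained() below to spot stale shard files from a previous
    # save; only "<prefix>-NNNNN-of-NNNNN" names (extension stripped) qualify.
    assert _REGEX_SHARD.fullmatch("diffusion_pytorch_model-00001-of-00003")
    assert _REGEX_SHARD.fullmatch("diffusion_pytorch_model") is None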
@@ -263,6 +270,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         save_function: Optional[Callable] = None,
         safe_serialization: bool = True,
         variant: Optional[str] = None,
+        max_shard_size: Union[int, str] = "10GB",
         push_to_hub: bool = False,
         **kwargs,
     ):
@@ -285,6 +293,13 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             variant (`str`, *optional*):
                 If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
+            max_shard_size (`int` or `str`, defaults to `"10GB"`):
+                The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
+                lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
+                If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
+                period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`.
+                This is to establish a common default size for this argument across different libraries in the Hugging
+                Face ecosystem (`transformers`, and `accelerate`, for example).
             push_to_hub (`bool`, *optional*, defaults to `False`):
                 Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
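In practice the new argument is a drop-in addition to save_pretrained(); a minimal sketch, assuming a Stable Diffusion 1.5 UNet (the runwayml/stable-diffusion-v1-5 id and output path are illustrative):

    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"
    )

    # Keep each checkpoint file under ~2GB; when the state dict exceeds the
    # limit, numbered shards plus a *.index.json weight map are written
    # instead of a single weights file.
    unet.save_pretrained("./sd15-unet-sharded", max_shard_size="2GB")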
@@ -296,6 +311,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
+        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+        weights_name = _add_variant(weights_name, variant)
+        weight_name_split = weights_name.split(".")
+        if len(weight_name_split) in [2, 3]:
+            weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
+        else:
+            raise ValueError(f"Invalid {weights_name} provided.")
+
         os.makedirs(save_directory, exist_ok=True)
 
         if push_to_hub:
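The "{suffix}" placeholder is later filled in by huggingface_hub with the shard number; tracing the split/join above for a variant checkpoint makes it concrete:

    weights_name = "diffusion_pytorch_model.fp16.safetensors"  # safetensors + "fp16" variant
    parts = weights_name.split(".")  # ["diffusion_pytorch_model", "fp16", "safetensors"]
    pattern = parts[0] + "{suffix}." + ".".join(parts[1:])
    print(pattern)
    # diffusion_pytorch_model{suffix}.fp16.safetensors
    # -> e.g. diffusion_pytorch_model-00001-of-00002.fp16.safetensors once sharded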
@@ -317,18 +340,58 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         # Save the model
         state_dict = model_to_save.state_dict()
 
-        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
-        weights_name = _add_variant(weights_name, variant)
-
         # Save the model
-        if safe_serialization:
-            safetensors.torch.save_file(
-                state_dict, Path(save_directory, weights_name).as_posix(), metadata={"format": "pt"}
+        state_dict_split = split_torch_state_dict_into_shards(
+            state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
+        )
+
+        # Clean the folder from a previous save
+        if is_main_process:
+            for filename in os.listdir(save_directory):
+                if filename in state_dict_split.filename_to_tensors.keys():
+                    continue
+                full_filename = os.path.join(save_directory, filename)
+                if not os.path.isfile(full_filename):
+                    continue
+                weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
+                weights_without_ext = weights_without_ext.replace("{suffix}", "")
+                filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
+                # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
+                if (
+                    filename.startswith(weights_without_ext)
+                    and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
+                ):
+                    os.remove(full_filename)
+
+        for filename, tensors in state_dict_split.filename_to_tensors.items():
+            shard = {tensor: state_dict[tensor] for tensor in tensors}
+            filepath = os.path.join(save_directory, filename)
+            if safe_serialization:
+                # At some point we will need to deal better with save_function (used for TPU and other distributed
+                # joyfulness), but for now this enough.
+                safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
+            else:
+                torch.save(shard, filepath)
+
+        if state_dict_split.is_sharded:
+            index = {
+                "metadata": state_dict_split.metadata,
+                "weight_map": state_dict_split.tensor_to_filename,
+            }
+            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+            save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
+            # Save the index as well
+            with open(save_index_file, "w", encoding="utf-8") as f:
+                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+                f.write(content)
+            logger.info(
+                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+                f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
+                f"index located at {save_index_file}."
             )
         else:
-            torch.save(state_dict, Path(save_directory, weights_name).as_posix())
-
-        logger.info(f"Model weights saved in {Path(save_directory, weights_name).as_posix()}")
+            path_to_weights = os.path.join(save_directory, weights_name)
+            logger.info(f"Model weights saved in {path_to_weights}")
 
         if push_to_hub:
             # Create a new empty model card and eventually tag it
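To see what split_torch_state_dict_into_shards hands back, here is a small self-contained sketch (synthetic tensors and an illustrative filename pattern):

    import torch
    from huggingface_hub import split_torch_state_dict_into_shards

    # Eight 4MB float32 tensors against a 10MB budget -> four shards of two tensors.
    state_dict = {f"layer.{i}.weight": torch.zeros(1000, 1000) for i in range(8)}
    split = split_torch_state_dict_into_shards(
        state_dict, max_shard_size="10MB", filename_pattern="model{suffix}.safetensors"
    )
    print(split.is_sharded)                   # True
    print(sorted(split.filename_to_tensors))  # ['model-00001-of-00004.safetensors', ...]
    print(split.tensor_to_filename["layer.0.weight"])

The is_sharded flag drives the branch above: sharded saves write the index JSON, single-file saves keep the old log message.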
@@ -566,6 +629,32 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             **kwargs,
         )
 
+        # Determine if we're loading from a directory of sharded checkpoints.
+        is_sharded = False
+        index_file = None
+        is_local = os.path.isdir(pretrained_model_name_or_path)
+        index_file = _fetch_index_file(
+            is_local=is_local,
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            subfolder=subfolder or "",
+            use_safetensors=use_safetensors,
+            cache_dir=cache_dir,
+            variant=variant,
+            force_download=force_download,
+            resume_download=resume_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            user_agent=user_agent,
+            commit_hash=commit_hash,
+        )
+        if index_file is not None and index_file.is_file():
+            is_sharded = True
+
+        if is_sharded and from_flax:
+            raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")
+
         # load model
         model_file = None
         if from_flax:
@@ -590,7 +679,21 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
 
             model = load_flax_checkpoint_in_pytorch_model(model, model_file)
         else:
-            if use_safetensors:
+            if is_sharded:
+                sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
+                    pretrained_model_name_or_path,
+                    index_file,
+                    cache_dir=cache_dir,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    token=token,
+                    user_agent=user_agent,
+                    revision=revision,
+                    subfolder=subfolder or "",
+                )
+
+            elif use_safetensors and not is_sharded:
                 try:
                     model_file = _get_model_file(
                         pretrained_model_name_or_path,
@@ -606,11 +709,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                         user_agent=user_agent,
                         commit_hash=commit_hash,
                     )
+
                 except IOError as e:
+                    logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
                     if not allow_pickle:
-                        raise e
-                    pass
-            if model_file is None:
+                        raise
+                    logger.warning(
+                        "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
+                    )
+
+            if model_file is None and not is_sharded:
                 model_file = _get_model_file(
                     pretrained_model_name_or_path,
                     weights_name=_add_variant(WEIGHTS_NAME, variant),
@@ -632,7 +740,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 model = cls.from_config(config, **unused_kwargs)
 
             # if device_map is None, load the state dict and move the params from meta device to the cpu
-            if device_map is None:
+            if device_map is None and not is_sharded:
                 param_device = "cpu"
                 state_dict = load_state_dict(model_file, variant=variant)
                 model._convert_deprecated_attention_blocks(state_dict)
@@ -670,7 +778,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 try:
                     accelerate.load_checkpoint_and_dispatch(
                         model,
-                        model_file,
+                        model_file if not is_sharded else sharded_ckpt_cached_folder,
                         device_map,
                         max_memory=max_memory,
                         offload_folder=offload_folder,
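Taken together, these load-path changes make sharded repositories transparent to callers; a minimal sketch, assuming the SD3 transformer checkpoint is published in sharded form (the model id is illustrative and gated behind a license acceptance):

    import torch
    from diffusers import SD3Transformer2DModel

    # from_pretrained() finds the *.index.json, downloads every shard it lists,
    # and (when a device_map is set) lets accelerate dispatch straight from the
    # cached shard folder instead of a single weights file.
    model = SD3Transformer2DModel.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",
        subfolder="transformer",
        torch_dtype=torch.float16,
    )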
@@ -1057,6 +1165,9 @@ class LegacyModelMixin(ModelMixin):
         # To prevent depedency import problem.
         from .model_loading_utils import _fetch_remapped_cls_from_config
 
+        # Create a copy of the kwargs so that we don't mess with the keyword arguments in the downstream calls.
+        kwargs_copy = kwargs.copy()
+
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         resume_download = kwargs.pop("resume_download", None)
@@ -1094,4 +1205,4 @@ class LegacyModelMixin(ModelMixin):
         # resolve remapping
         remapped_class = _fetch_remapped_cls_from_config(config, cls)
 
-        return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
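Why the copy matters: every kwargs.pop(...) in this method mutates the dict that was previously forwarded verbatim, so arguments consumed while resolving the remapped class never reached the downstream call. A schematic stand-in (_legacy_from_pretrained is hypothetical, not the real method):

    def _legacy_from_pretrained(path, **kwargs):
        kwargs_copy = kwargs.copy()                # snapshot before any pop()
        cache_dir = kwargs.pop("cache_dir", None)  # mutates `kwargs` in place
        # ... config loading / class remapping would use cache_dir etc. here ...
        return kwargs_copy                         # what the remapped class now receives

    print(_legacy_from_pretrained("some/model", cache_dir="/tmp/hf", torch_dtype="fp16"))
    # {'cache_dir': '/tmp/hf', 'torch_dtype': 'fp16'} -- nothing silently dropped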
diffusers/models/normalization.py

@@ -57,10 +57,12 @@ class AdaLayerNormZero(nn.Module):
         num_embeddings (`int`): The size of the embeddings dictionary.
     """
 
-    def __init__(self, embedding_dim: int, num_embeddings: int):
+    def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None):
         super().__init__()
-
-        self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
+        if num_embeddings is not None:
+            self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
+        else:
+            self.emb = None
 
         self.silu = nn.SiLU()
         self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
@@ -69,11 +71,14 @@ class AdaLayerNormZero(nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        timestep: torch.Tensor,
-        class_labels: torch.LongTensor,
+        timestep: Optional[torch.Tensor] = None,
+        class_labels: Optional[torch.LongTensor] = None,
         hidden_dtype: Optional[torch.dtype] = None,
+        emb: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
+        if self.emb is not None:
+            emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
+        emb = self.linear(self.silu(emb))
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
         return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
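This change lets callers construct AdaLayerNormZero without a label-embedding table and feed a precomputed conditioning embedding directly, which is the path the new SD3 transformer takes. A minimal sketch of the new calling convention:

    import torch
    from diffusers.models.normalization import AdaLayerNormZero

    # New in 0.29.0: num_embeddings=None skips CombinedTimestepLabelEmbeddings,
    # so the caller supplies `emb` (e.g. a combined timestep/text embedding).
    norm = AdaLayerNormZero(embedding_dim=64, num_embeddings=None)
    x = torch.randn(2, 16, 64)    # (batch, sequence, embedding_dim)
    emb = torch.randn(2, 64)      # precomputed conditioning embedding
    x_out, gate_msa, shift_mlp, scale_mlp, gate_mlp = norm(x, emb=emb)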
diffusers/models/transformers/__init__.py

@@ -9,4 +9,5 @@ if is_torch_available():
     from .prior_transformer import PriorTransformer
     from .t5_film_transformer import T5FilmDecoder
     from .transformer_2d import Transformer2DModel
+    from .transformer_sd3 import SD3Transformer2DModel
     from .transformer_temporal import TransformerTemporalModel
diffusers/models/transformers/dual_transformer_2d.py

@@ -15,7 +15,8 @@ from typing import Optional
 
 from torch import nn
 
-from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
+from ..modeling_outputs import Transformer2DModelOutput
+from .transformer_2d import Transformer2DModel
 
 
 class DualTransformer2DModel(nn.Module):
@@ -123,9 +124,9 @@ class DualTransformer2DModel(nn.Module):
                 tuple.
 
         Returns:
-            [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
-            [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
-            returning a tuple, the first element is the sample tensor.
+            [`~models.transformers.transformer_2d.Transformer2DModelOutput`] or `tuple`:
+            [`~models.transformers.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is the sample tensor.
         """
         input_states = hidden_states
 
diffusers/models/transformers/hunyuan_transformer_2d.py

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import Dict, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
-from ..attention_processor import Attention, HunyuanAttnProcessor2_0
+from ..attention_processor import Attention, AttentionProcessor, HunyuanAttnProcessor2_0
 from ..embeddings import (
     HunyuanCombinedTimestepTextSizeStyleEmbedding,
     PatchEmbed,
@@ -166,6 +166,7 @@ class HunyuanDiTBlock(nn.Module):
         self._chunk_size = None
         self._chunk_dim = 0
 
+    # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
     def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
         # Sets chunk feed-forward
         self._chunk_size = chunk_size
@@ -321,6 +322,110 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
         self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
 
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(HunyuanAttnProcessor2_0())
+
     def forward(
         self,
         hidden_states,
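With the attention-processor API now ported to HunyuanDiT, QKV fusion can be toggled around inference; a hedged sketch, assuming the Tencent-Hunyuan/HunyuanDiT-Diffusers checkpoint id:

    import torch
    from diffusers import HunyuanDiTPipeline

    pipe = HunyuanDiTPipeline.from_pretrained(
        "Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
    ).to("cuda")

    # Experimental: collapse Q/K/V projections into one matmul per self-attention block.
    pipe.transformer.fuse_qkv_projections()
    image = pipe("a watercolor fox in a bamboo forest").images[0]
    pipe.transformer.unfuse_qkv_projections()  # restore the original processors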
@@ -425,3 +530,45 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
+
+    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+        """
+        Sets the attention processor to use [feed forward
+        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+        Parameters:
+            chunk_size (`int`, *optional*):
+                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+                over each tensor of dim=`dim`.
+            dim (`int`, *optional*, defaults to `0`):
+                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+                or dim=1 (sequence length).
+        """
+        if dim not in [0, 1]:
+            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+        # By default chunk size is 1
+        chunk_size = chunk_size or 1
+
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, chunk_size, dim)
+
+    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
+    def disable_forward_chunking(self):
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, None, 0)
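Forward chunking trades throughput for peak memory by slicing the feed-forward activations; continuing the sketch above:

    # Chunk the feed-forward computation over the sequence dimension (dim=1) so
    # only one slice of the MLP activations is materialized at a time.
    pipe.transformer.enable_forward_chunking(chunk_size=1, dim=1)
    image = pipe("a watercolor fox in a bamboo forest").images[0]
    pipe.transformer.disable_forward_chunking()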
diffusers/models/transformers/prior_transformer.py

@@ -266,13 +266,13 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
             attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
                 Text mask for the text embeddings.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
-                tuple.
+                Whether or not to return a [`~models.transformers.prior_transformer.PriorTransformerOutput`] instead of
+                a plain tuple.
 
         Returns:
-            [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
-            If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
-            tuple is returned where the first element is the sample tensor.
+            [`~models.transformers.prior_transformer.PriorTransformerOutput`] or `tuple`:
+            If return_dict is True, a [`~models.transformers.prior_transformer.PriorTransformerOutput`] is
+            returned, otherwise a tuple is returned where the first element is the sample tensor.
         """
         batch_size = hidden_states.shape[0]
 
diffusers/models/transformers/transformer_2d.py

@@ -369,8 +369,8 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
                 tuple.
 
         Returns:
-            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
-            `tuple` where the first element is the sample tensor.
+            If `return_dict` is True, an [`~models.transformers.transformer_2d.Transformer2DModelOutput`] is returned,
+            otherwise a `tuple` where the first element is the sample tensor.
         """
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None: