diffusers 0.29.2__py3-none-any.whl → 0.30.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220) hide show
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_pipeline.py +2222 -0
  10. diffusers/loaders/peft.py +213 -5
  11. diffusers/loaders/single_file.py +1 -12
  12. diffusers/loaders/single_file_model.py +31 -10
  13. diffusers/loaders/single_file_utils.py +262 -2
  14. diffusers/loaders/textual_inversion.py +1 -6
  15. diffusers/loaders/unet.py +23 -208
  16. diffusers/models/__init__.py +20 -0
  17. diffusers/models/activations.py +22 -0
  18. diffusers/models/attention.py +386 -7
  19. diffusers/models/attention_processor.py +1795 -629
  20. diffusers/models/autoencoders/__init__.py +2 -0
  21. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  22. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1035 -0
  23. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  24. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  25. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  26. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  27. diffusers/models/autoencoders/vq_model.py +4 -4
  28. diffusers/models/controlnet.py +2 -3
  29. diffusers/models/controlnet_hunyuan.py +401 -0
  30. diffusers/models/controlnet_sd3.py +11 -11
  31. diffusers/models/controlnet_sparsectrl.py +789 -0
  32. diffusers/models/controlnet_xs.py +40 -10
  33. diffusers/models/downsampling.py +68 -0
  34. diffusers/models/embeddings.py +319 -36
  35. diffusers/models/model_loading_utils.py +1 -3
  36. diffusers/models/modeling_flax_utils.py +1 -6
  37. diffusers/models/modeling_utils.py +4 -16
  38. diffusers/models/normalization.py +203 -12
  39. diffusers/models/transformers/__init__.py +6 -0
  40. diffusers/models/transformers/auraflow_transformer_2d.py +527 -0
  41. diffusers/models/transformers/cogvideox_transformer_3d.py +345 -0
  42. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  43. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  44. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  45. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  46. diffusers/models/transformers/prior_transformer.py +1 -1
  47. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  48. diffusers/models/transformers/transformer_flux.py +455 -0
  49. diffusers/models/transformers/transformer_sd3.py +18 -4
  50. diffusers/models/unets/unet_1d_blocks.py +1 -1
  51. diffusers/models/unets/unet_2d_condition.py +8 -1
  52. diffusers/models/unets/unet_3d_blocks.py +51 -920
  53. diffusers/models/unets/unet_3d_condition.py +4 -1
  54. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  55. diffusers/models/unets/unet_kandinsky3.py +1 -1
  56. diffusers/models/unets/unet_motion_model.py +1330 -84
  57. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  58. diffusers/models/unets/unet_stable_cascade.py +1 -3
  59. diffusers/models/unets/uvit_2d.py +1 -1
  60. diffusers/models/upsampling.py +64 -0
  61. diffusers/models/vq_model.py +8 -4
  62. diffusers/optimization.py +1 -1
  63. diffusers/pipelines/__init__.py +100 -3
  64. diffusers/pipelines/animatediff/__init__.py +4 -0
  65. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  66. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  70. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  71. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  72. diffusers/pipelines/aura_flow/__init__.py +48 -0
  73. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  74. diffusers/pipelines/auto_pipeline.py +97 -19
  75. diffusers/pipelines/cogvideo/__init__.py +48 -0
  76. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +687 -0
  77. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  78. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  79. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  80. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  81. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  82. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  83. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  84. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  85. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  86. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  87. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  88. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  90. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  91. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  96. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  97. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  98. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  103. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  104. diffusers/pipelines/flux/__init__.py +47 -0
  105. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  106. diffusers/pipelines/flux/pipeline_output.py +21 -0
  107. diffusers/pipelines/free_init_utils.py +2 -0
  108. diffusers/pipelines/free_noise_utils.py +236 -0
  109. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  110. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  111. diffusers/pipelines/kolors/__init__.py +54 -0
  112. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  113. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  114. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  115. diffusers/pipelines/kolors/text_encoder.py +889 -0
  116. diffusers/pipelines/kolors/tokenizer.py +334 -0
  117. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  118. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  119. diffusers/pipelines/latte/__init__.py +48 -0
  120. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  121. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  122. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  123. diffusers/pipelines/lumina/__init__.py +48 -0
  124. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  125. diffusers/pipelines/pag/__init__.py +67 -0
  126. diffusers/pipelines/pag/pag_utils.py +237 -0
  127. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  128. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  129. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  130. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  131. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  132. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  133. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  138. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  139. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  140. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  141. diffusers/pipelines/pipeline_utils.py +2 -14
  142. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  143. diffusers/pipelines/stable_audio/__init__.py +50 -0
  144. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  145. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  146. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  147. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  156. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  157. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  158. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  160. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  161. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  162. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  163. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  164. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  165. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  166. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  167. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  168. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  175. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  179. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  180. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  181. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  182. diffusers/schedulers/__init__.py +8 -0
  183. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  184. diffusers/schedulers/scheduling_ddim.py +1 -1
  185. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  186. diffusers/schedulers/scheduling_ddpm.py +1 -1
  187. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  188. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  189. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  190. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  191. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  192. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  193. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  194. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  195. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  196. diffusers/schedulers/scheduling_ipndm.py +1 -1
  197. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  198. diffusers/schedulers/scheduling_utils.py +1 -3
  199. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  200. diffusers/training_utils.py +99 -14
  201. diffusers/utils/__init__.py +2 -2
  202. diffusers/utils/dummy_pt_objects.py +210 -0
  203. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  204. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  205. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  206. diffusers/utils/dynamic_modules_utils.py +1 -11
  207. diffusers/utils/export_utils.py +1 -4
  208. diffusers/utils/hub_utils.py +45 -42
  209. diffusers/utils/import_utils.py +19 -16
  210. diffusers/utils/loading_utils.py +76 -3
  211. diffusers/utils/testing_utils.py +11 -8
  212. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/METADATA +73 -83
  213. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/RECORD +217 -164
  214. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/WHEEL +1 -1
  215. diffusers/loaders/autoencoder.py +0 -146
  216. diffusers/loaders/controlnet.py +0 -136
  217. diffusers/loaders/lora.py +0 -1728
  218. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/LICENSE +0 -0
  219. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/entry_points.txt +0 -0
  220. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/top_level.txt +0 -0
@@ -22,7 +22,10 @@ import torch.nn.functional as F
22
22
 
23
23
  from ..utils import is_torch_version
24
24
  from .activations import get_activation
25
- from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
25
+ from .embeddings import (
26
+ CombinedTimestepLabelEmbeddings,
27
+ PixArtAlphaCombinedTimestepSizeEmbeddings,
28
+ )
26
29
 
27
30
 
28
31
  class AdaLayerNorm(nn.Module):
@@ -31,23 +34,69 @@ class AdaLayerNorm(nn.Module):
31
34
 
32
35
  Parameters:
33
36
  embedding_dim (`int`): The size of each embedding vector.
34
- num_embeddings (`int`): The size of the embeddings dictionary.
37
+ num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
38
+ output_dim (`int`, *optional*): The output dimension of the linear projection; defaults to `2 * embedding_dim`.
39
+ norm_elementwise_affine (`bool`, defaults to `False`):
40
+ norm_eps (`float`, defaults to `1e-5`):
41
+ chunk_dim (`int`, defaults to `0`): The dimension along which the projected embedding is split into scale and shift.
35
42
  """
36
43
 
37
- def __init__(self, embedding_dim: int, num_embeddings: int):
44
+ def __init__(
45
+ self,
46
+ embedding_dim: int,
47
+ num_embeddings: Optional[int] = None,
48
+ output_dim: Optional[int] = None,
49
+ norm_elementwise_affine: bool = False,
50
+ norm_eps: float = 1e-5,
51
+ chunk_dim: int = 0,
52
+ ):
38
53
  super().__init__()
39
- self.emb = nn.Embedding(num_embeddings, embedding_dim)
54
+
55
+ self.chunk_dim = chunk_dim
56
+ output_dim = output_dim or embedding_dim * 2
57
+
58
+ if num_embeddings is not None:
59
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
60
+ else:
61
+ self.emb = None
62
+
40
63
  self.silu = nn.SiLU()
41
- self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
42
- self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
64
+ self.linear = nn.Linear(embedding_dim, output_dim)
65
+ self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
66
+
67
+ def forward(
68
+ self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
69
+ ) -> torch.Tensor:
70
+ if self.emb is not None:
71
+ temb = self.emb(timestep)
72
+
73
+ temb = self.linear(self.silu(temb))
74
+
75
+ if self.chunk_dim == 1:
76
+ # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
77
+ # other if-branch. This branch is specific to CogVideoX for now.
78
+ shift, scale = temb.chunk(2, dim=1)
79
+ shift = shift[:, None, :]
80
+ scale = scale[:, None, :]
81
+ else:
82
+ scale, shift = temb.chunk(2, dim=0)
43
83
 
44
- def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
45
- emb = self.linear(self.silu(self.emb(timestep)))
46
- scale, shift = torch.chunk(emb, 2)
47
84
  x = self.norm(x) * (1 + scale) + shift
48
85
  return x
49
86
 
50
87
 
88
+ class FP32LayerNorm(nn.LayerNorm):
89
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
90
+ origin_dtype = inputs.dtype
91
+ return F.layer_norm(
92
+ inputs.float(),
93
+ self.normalized_shape,
94
+ self.weight.float() if self.weight is not None else None,
95
+ self.bias.float() if self.bias is not None else None,
96
+ self.eps,
97
+ ).to(origin_dtype)
98
+
99
+
51
100
  class AdaLayerNormZero(nn.Module):
52
101
  r"""
53
102
  Norm layer adaptive layer norm zero (adaLN-Zero).
@@ -57,7 +106,7 @@ class AdaLayerNormZero(nn.Module):
57
106
  num_embeddings (`int`): The size of the embeddings dictionary.
58
107
  """
59
108
 
60
- def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None):
109
+ def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
61
110
  super().__init__()
62
111
  if num_embeddings is not None:
63
112
  self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
@@ -65,8 +114,15 @@ class AdaLayerNormZero(nn.Module):
65
114
  self.emb = None
66
115
 
67
116
  self.silu = nn.SiLU()
68
- self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
69
- self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
117
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
118
+ if norm_type == "layer_norm":
119
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
120
+ elif norm_type == "fp32_layer_norm":
121
+ self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
122
+ else:
123
+ raise ValueError(
124
+ f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
125
+ )
70
126
 
71
127
  def forward(
72
128
  self,
@@ -84,6 +140,69 @@ class AdaLayerNormZero(nn.Module):
84
140
  return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
85
141
 
86
142
 
143
+ class AdaLayerNormZeroSingle(nn.Module):
144
+ r"""
145
+ Norm layer adaptive layer norm zero (adaLN-Zero).
146
+
147
+ Parameters:
148
+ embedding_dim (`int`): The size of each embedding vector.
149
+ num_embeddings (`int`): The size of the embeddings dictionary.
150
+ """
151
+
152
+ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
153
+ super().__init__()
154
+
155
+ self.silu = nn.SiLU()
156
+ self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
157
+ if norm_type == "layer_norm":
158
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
159
+ else:
160
+ raise ValueError(
161
+ f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
162
+ )
163
+
164
+ def forward(
165
+ self,
166
+ x: torch.Tensor,
167
+ emb: Optional[torch.Tensor] = None,
168
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
169
+ emb = self.linear(self.silu(emb))
170
+ shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
171
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
172
+ return x, gate_msa
173
+
174
+
175
+ class LuminaRMSNormZero(nn.Module):
176
+ """
177
+ Norm layer adaptive RMS normalization zero.
178
+
179
+ Parameters:
180
+ embedding_dim (`int`): The size of each embedding vector.
181
+ """
182
+
183
+ def __init__(self, embedding_dim: int, norm_eps: float, norm_elementwise_affine: bool):
184
+ super().__init__()
185
+ self.silu = nn.SiLU()
186
+ self.linear = nn.Linear(
187
+ min(embedding_dim, 1024),
188
+ 4 * embedding_dim,
189
+ bias=True,
190
+ )
191
+ self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
192
+
193
+ def forward(
194
+ self,
195
+ x: torch.Tensor,
196
+ emb: Optional[torch.Tensor] = None,
197
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
198
+ # emb = self.emb(timestep, encoder_hidden_states, encoder_mask)
199
+ emb = self.linear(self.silu(emb))
200
+ scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
201
+ x = self.norm(x) * (1 + scale_msa[:, None])
202
+
203
+ return x, gate_msa, scale_mlp, gate_mlp
204
+
205
+
87
206
  class AdaLayerNormSingle(nn.Module):
88
207
  r"""
89
208
  Norm layer adaptive layer norm single (adaLN-single).
@@ -188,6 +307,78 @@ class AdaLayerNormContinuous(nn.Module):
188
307
  return x
189
308
 
190
309
 
310
+ class LuminaLayerNormContinuous(nn.Module):
311
+ def __init__(
312
+ self,
313
+ embedding_dim: int,
314
+ conditioning_embedding_dim: int,
315
+ # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
316
+ # because the output is immediately scaled and shifted by the projected conditioning embeddings.
317
+ # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
318
+ # However, this is how it was implemented in the original code, and it's rather likely you should
319
+ # set `elementwise_affine` to False.
320
+ elementwise_affine=True,
321
+ eps=1e-5,
322
+ bias=True,
323
+ norm_type="layer_norm",
324
+ out_dim: Optional[int] = None,
325
+ ):
326
+ super().__init__()
327
+ # AdaLN
328
+ self.silu = nn.SiLU()
329
+ self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
330
+ if norm_type == "layer_norm":
331
+ self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
332
+ else:
333
+ raise ValueError(f"unknown norm_type {norm_type}")
334
+ # linear_2
335
+ if out_dim is not None:
336
+ self.linear_2 = nn.Linear(
337
+ embedding_dim,
338
+ out_dim,
339
+ bias=bias,
340
+ )
341
+
342
+ def forward(
343
+ self,
344
+ x: torch.Tensor,
345
+ conditioning_embedding: torch.Tensor,
346
+ ) -> torch.Tensor:
347
+ # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
348
+ emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
349
+ scale = emb
350
+ x = self.norm(x) * (1 + scale)[:, None, :]
351
+
352
+ if self.linear_2 is not None:
353
+ x = self.linear_2(x)
354
+
355
+ return x
356
+
357
+
358
+ class CogVideoXLayerNormZero(nn.Module):
359
+ def __init__(
360
+ self,
361
+ conditioning_dim: int,
362
+ embedding_dim: int,
363
+ elementwise_affine: bool = True,
364
+ eps: float = 1e-5,
365
+ bias: bool = True,
366
+ ) -> None:
367
+ super().__init__()
368
+
369
+ self.silu = nn.SiLU()
370
+ self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
371
+ self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
372
+
373
+ def forward(
374
+ self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
375
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
376
+ shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
377
+ hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
378
+ encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
379
+ return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
380
+
381
+
191
382
  if is_torch_version(">=", "2.1.0"):
192
383
  LayerNorm = nn.LayerNorm
193
384
  else:
@@ -2,12 +2,18 @@ from ...utils import is_torch_available
2
2
 
3
3
 
4
4
  if is_torch_available():
5
+ from .auraflow_transformer_2d import AuraFlowTransformer2DModel
6
+ from .cogvideox_transformer_3d import CogVideoXTransformer3DModel
5
7
  from .dit_transformer_2d import DiTTransformer2DModel
6
8
  from .dual_transformer_2d import DualTransformer2DModel
7
9
  from .hunyuan_transformer_2d import HunyuanDiT2DModel
10
+ from .latte_transformer_3d import LatteTransformer3DModel
11
+ from .lumina_nextdit2d import LuminaNextDiT2DModel
8
12
  from .pixart_transformer_2d import PixArtTransformer2DModel
9
13
  from .prior_transformer import PriorTransformer
14
+ from .stable_audio_transformer import StableAudioDiTModel
10
15
  from .t5_film_transformer import T5FilmDecoder
11
16
  from .transformer_2d import Transformer2DModel
17
+ from .transformer_flux import FluxTransformer2DModel
12
18
  from .transformer_sd3 import SD3Transformer2DModel
13
19
  from .transformer_temporal import TransformerTemporalModel