diffusers 0.29.2__py3-none-any.whl → 0.30.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_pipeline.py +2222 -0
  10. diffusers/loaders/peft.py +213 -5
  11. diffusers/loaders/single_file.py +1 -12
  12. diffusers/loaders/single_file_model.py +31 -10
  13. diffusers/loaders/single_file_utils.py +262 -2
  14. diffusers/loaders/textual_inversion.py +1 -6
  15. diffusers/loaders/unet.py +23 -208
  16. diffusers/models/__init__.py +20 -0
  17. diffusers/models/activations.py +22 -0
  18. diffusers/models/attention.py +386 -7
  19. diffusers/models/attention_processor.py +1795 -629
  20. diffusers/models/autoencoders/__init__.py +2 -0
  21. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  22. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1035 -0
  23. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  24. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  25. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  26. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  27. diffusers/models/autoencoders/vq_model.py +4 -4
  28. diffusers/models/controlnet.py +2 -3
  29. diffusers/models/controlnet_hunyuan.py +401 -0
  30. diffusers/models/controlnet_sd3.py +11 -11
  31. diffusers/models/controlnet_sparsectrl.py +789 -0
  32. diffusers/models/controlnet_xs.py +40 -10
  33. diffusers/models/downsampling.py +68 -0
  34. diffusers/models/embeddings.py +319 -36
  35. diffusers/models/model_loading_utils.py +1 -3
  36. diffusers/models/modeling_flax_utils.py +1 -6
  37. diffusers/models/modeling_utils.py +4 -16
  38. diffusers/models/normalization.py +203 -12
  39. diffusers/models/transformers/__init__.py +6 -0
  40. diffusers/models/transformers/auraflow_transformer_2d.py +527 -0
  41. diffusers/models/transformers/cogvideox_transformer_3d.py +345 -0
  42. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  43. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  44. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  45. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  46. diffusers/models/transformers/prior_transformer.py +1 -1
  47. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  48. diffusers/models/transformers/transformer_flux.py +455 -0
  49. diffusers/models/transformers/transformer_sd3.py +18 -4
  50. diffusers/models/unets/unet_1d_blocks.py +1 -1
  51. diffusers/models/unets/unet_2d_condition.py +8 -1
  52. diffusers/models/unets/unet_3d_blocks.py +51 -920
  53. diffusers/models/unets/unet_3d_condition.py +4 -1
  54. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  55. diffusers/models/unets/unet_kandinsky3.py +1 -1
  56. diffusers/models/unets/unet_motion_model.py +1330 -84
  57. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  58. diffusers/models/unets/unet_stable_cascade.py +1 -3
  59. diffusers/models/unets/uvit_2d.py +1 -1
  60. diffusers/models/upsampling.py +64 -0
  61. diffusers/models/vq_model.py +8 -4
  62. diffusers/optimization.py +1 -1
  63. diffusers/pipelines/__init__.py +100 -3
  64. diffusers/pipelines/animatediff/__init__.py +4 -0
  65. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  66. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  70. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  71. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  72. diffusers/pipelines/aura_flow/__init__.py +48 -0
  73. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  74. diffusers/pipelines/auto_pipeline.py +97 -19
  75. diffusers/pipelines/cogvideo/__init__.py +48 -0
  76. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +687 -0
  77. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  78. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  79. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  80. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  81. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  82. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  83. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  84. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  85. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  86. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  87. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  88. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  90. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  91. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  96. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  97. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  98. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  103. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  104. diffusers/pipelines/flux/__init__.py +47 -0
  105. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  106. diffusers/pipelines/flux/pipeline_output.py +21 -0
  107. diffusers/pipelines/free_init_utils.py +2 -0
  108. diffusers/pipelines/free_noise_utils.py +236 -0
  109. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  110. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  111. diffusers/pipelines/kolors/__init__.py +54 -0
  112. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  113. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  114. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  115. diffusers/pipelines/kolors/text_encoder.py +889 -0
  116. diffusers/pipelines/kolors/tokenizer.py +334 -0
  117. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  118. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  119. diffusers/pipelines/latte/__init__.py +48 -0
  120. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  121. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  122. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  123. diffusers/pipelines/lumina/__init__.py +48 -0
  124. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  125. diffusers/pipelines/pag/__init__.py +67 -0
  126. diffusers/pipelines/pag/pag_utils.py +237 -0
  127. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  128. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  129. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  130. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  131. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  132. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  133. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  138. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  139. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  140. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  141. diffusers/pipelines/pipeline_utils.py +2 -14
  142. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  143. diffusers/pipelines/stable_audio/__init__.py +50 -0
  144. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  145. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  146. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  147. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  156. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  157. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  158. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  160. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  161. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  162. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  163. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  164. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  165. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  166. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  167. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  168. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  175. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  179. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  180. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  181. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  182. diffusers/schedulers/__init__.py +8 -0
  183. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  184. diffusers/schedulers/scheduling_ddim.py +1 -1
  185. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  186. diffusers/schedulers/scheduling_ddpm.py +1 -1
  187. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  188. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  189. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  190. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  191. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  192. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  193. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  194. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  195. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  196. diffusers/schedulers/scheduling_ipndm.py +1 -1
  197. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  198. diffusers/schedulers/scheduling_utils.py +1 -3
  199. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  200. diffusers/training_utils.py +99 -14
  201. diffusers/utils/__init__.py +2 -2
  202. diffusers/utils/dummy_pt_objects.py +210 -0
  203. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  204. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  205. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  206. diffusers/utils/dynamic_modules_utils.py +1 -11
  207. diffusers/utils/export_utils.py +1 -4
  208. diffusers/utils/hub_utils.py +45 -42
  209. diffusers/utils/import_utils.py +19 -16
  210. diffusers/utils/loading_utils.py +76 -3
  211. diffusers/utils/testing_utils.py +11 -8
  212. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/METADATA +73 -83
  213. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/RECORD +217 -164
  214. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/WHEEL +1 -1
  215. diffusers/loaders/autoencoder.py +0 -146
  216. diffusers/loaders/controlnet.py +0 -136
  217. diffusers/loaders/lora.py +0 -1728
  218. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/LICENSE +0 -0
  219. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/entry_points.txt +0 -0
  220. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/top_level.txt +0 -0
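The largest additions are new pipelines (AuraFlow, CogVideoX, Flux, Kolors, Latte, Lumina, PAG variants, Stable Audio) together with the models and schedulers backing them. As a quick orientation, the sketch below exercises one of the newly added pipelines; it is not part of the diff, and the checkpoint id and call settings are assumptions for illustration only.

```python
# Minimal sketch (not from this diff): trying one of the pipelines added in 0.30.0.
# The checkpoint id and generation settings below are assumptions, not pinned by this package.
import torch
from diffusers import FluxPipeline  # new in diffusers/pipelines/flux/

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # optional: trades speed for lower VRAM use

image = pipe(
    "a tiny astronaut hatching from an egg on the moon",
    num_inference_steps=4,
    guidance_scale=0.0,
).images[0]
image.save("flux_schnell.png")
```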
diffusers/models/__init__.py
@@ -28,22 +28,32 @@ if is_torch_available():
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
     _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
     _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
+    _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
     _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
+    _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
     _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
     _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
     _import_structure["autoencoders.vq_model"] = ["VQModel"]
     _import_structure["controlnet"] = ["ControlNetModel"]
+    _import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
     _import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
+    _import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"]
     _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
     _import_structure["embeddings"] = ["ImageProjection"]
     _import_structure["modeling_utils"] = ["ModelMixin"]
+    _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
+    _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"]
     _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
     _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
     _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
+    _import_structure["transformers.latte_transformer_3d"] = ["LatteTransformer3DModel"]
+    _import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
     _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
     _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
+    _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
     _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
     _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
+    _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
     _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
     _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
     _import_structure["unets.unet_1d"] = ["UNet1DModel"]
@@ -69,23 +79,33 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        from .autoencoders import (
            AsymmetricAutoencoderKL,
            AutoencoderKL,
+           AutoencoderKLCogVideoX,
            AutoencoderKLTemporalDecoder,
+           AutoencoderOobleck,
            AutoencoderTiny,
            ConsistencyDecoderVAE,
            VQModel,
        )
        from .controlnet import ControlNetModel
+       from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
        from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
+       from .controlnet_sparsectrl import SparseControlNetModel
        from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
        from .embeddings import ImageProjection
        from .modeling_utils import ModelMixin
        from .transformers import (
+           AuraFlowTransformer2DModel,
+           CogVideoXTransformer3DModel,
            DiTTransformer2DModel,
            DualTransformer2DModel,
+           FluxTransformer2DModel,
            HunyuanDiT2DModel,
+           LatteTransformer3DModel,
+           LuminaNextDiT2DModel,
            PixArtTransformer2DModel,
            PriorTransformer,
            SD3Transformer2DModel,
+           StableAudioDiTModel,
            T5FilmDecoder,
            Transformer2DModel,
            TransformerTemporalModel,
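Both hunks above register the new 0.30.0 model classes in the lazy `_import_structure` map and the `TYPE_CHECKING` branch, so they become importable from the package root. A small sketch of what that enables (class names are taken directly from the hunks):

```python
# Sketch: the registrations above expose the new model classes at the package root.
from diffusers import (
    AutoencoderKLCogVideoX,
    AutoencoderOobleck,
    FluxTransformer2DModel,
    LatteTransformer3DModel,
    SparseControlNetModel,
)

print(FluxTransformer2DModel.__module__)  # diffusers.models.transformers.transformer_flux
```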

diffusers/models/activations.py
@@ -123,6 +123,28 @@ class GEGLU(nn.Module):
         return hidden_states * self.gelu(gate)
 
 
+class SwiGLU(nn.Module):
+    r"""
+    A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. It's similar to `GEGLU`
+    but uses SiLU / Swish instead of GeLU.
+
+    Parameters:
+        dim_in (`int`): The number of channels in the input.
+        dim_out (`int`): The number of channels in the output.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+    """
+
+    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
+        self.activation = nn.SiLU()
+
+    def forward(self, hidden_states):
+        hidden_states = self.proj(hidden_states)
+        hidden_states, gate = hidden_states.chunk(2, dim=-1)
+        return hidden_states * self.activation(gate)
+
+
 class ApproximateGELU(nn.Module):
     r"""
     The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
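`SwiGLU` follows the same pattern as `GEGLU`: a single projection to twice the output width, a chunk into value and gate, and a SiLU-gated product. A quick sketch of its shape behaviour, based only on the class above:

```python
# Sketch: shape behaviour of the SwiGLU activation added above.
import torch
from diffusers.models.activations import SwiGLU

act = SwiGLU(dim_in=64, dim_out=128)
x = torch.randn(2, 16, 64)
out = act(x)        # proj -> (2, 16, 256), chunk -> value/gate of width 128, value * silu(gate)
print(out.shape)    # torch.Size([2, 16, 128])
```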

diffusers/models/attention.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -19,7 +19,7 @@ from torch import nn
 
 from ..utils import deprecate, logging
 from ..utils.torch_utils import maybe_allow_in_graph
-from .activations import GEGLU, GELU, ApproximateGELU
+from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
 from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
 from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
@@ -128,9 +128,9 @@ class JointTransformerBlock(nn.Module):
             query_dim=dim,
             cross_attention_dim=None,
             added_kv_proj_dim=dim,
-            dim_head=attention_head_dim // num_attention_heads,
+            dim_head=attention_head_dim,
             heads=num_attention_heads,
-            out_dim=attention_head_dim,
+            out_dim=dim,
             context_pre_only=context_pre_only,
             bias=True,
             processor=processor,
@@ -272,6 +272,17 @@ class BasicTransformerBlock(nn.Module):
         attention_out_bias: bool = True,
     ):
         super().__init__()
+        self.dim = dim
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_dim = attention_head_dim
+        self.dropout = dropout
+        self.cross_attention_dim = cross_attention_dim
+        self.activation_fn = activation_fn
+        self.attention_bias = attention_bias
+        self.double_self_attention = double_self_attention
+        self.norm_elementwise_affine = norm_elementwise_affine
+        self.positional_embeddings = positional_embeddings
+        self.num_positional_embeddings = num_positional_embeddings
         self.only_cross_attention = only_cross_attention
 
         # We keep these boolean flags for backward-compatibility.
@@ -359,7 +370,10 @@
                 out_bias=attention_out_bias,
             )  # is self-attn if encoder_hidden_states is none
         else:
-            self.norm2 = None
+            if norm_type == "ada_norm_single":  # For Latte
+                self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+            else:
+                self.norm2 = None
             self.attn2 = None
 
         # 3. Feed-forward
@@ -373,7 +387,7 @@
                 "layer_norm",
             )
 
-        elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm", "ada_norm_continuous"]:
+        elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
             self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
         elif norm_type == "layer_norm_i2vgen":
             self.norm3 = None
@@ -439,7 +453,6 @@
             ).chunk(6, dim=1)
             norm_hidden_states = self.norm1(hidden_states)
             norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
-            norm_hidden_states = norm_hidden_states.squeeze(1)
         else:
             raise ValueError("Incorrect norm used")
 
@@ -456,6 +469,7 @@
             attention_mask=attention_mask,
             **cross_attention_kwargs,
         )
+
         if self.norm_type == "ada_norm_zero":
             attn_output = gate_msa.unsqueeze(1) * attn_output
         elif self.norm_type == "ada_norm_single":
@@ -527,6 +541,56 @@
         return hidden_states
 
 
+class LuminaFeedForward(nn.Module):
+    r"""
+    A feed-forward layer.
+
+    Parameters:
+        hidden_size (`int`):
+            The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
+            hidden representations.
+        intermediate_size (`int`): The intermediate dimension of the feedforward layer.
+        multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
+            of this value.
+        ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
+            dimension. Defaults to None.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        inner_dim: int,
+        multiple_of: Optional[int] = 256,
+        ffn_dim_multiplier: Optional[float] = None,
+    ):
+        super().__init__()
+        inner_dim = int(2 * inner_dim / 3)
+        # custom hidden_size factor multiplier
+        if ffn_dim_multiplier is not None:
+            inner_dim = int(ffn_dim_multiplier * inner_dim)
+        inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
+
+        self.linear_1 = nn.Linear(
+            dim,
+            inner_dim,
+            bias=False,
+        )
+        self.linear_2 = nn.Linear(
+            inner_dim,
+            dim,
+            bias=False,
+        )
+        self.linear_3 = nn.Linear(
+            dim,
+            inner_dim,
+            bias=False,
+        )
+        self.silu = FP32SiLU()
+
+    def forward(self, x):
+        return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
+
+
 @maybe_allow_in_graph
 class TemporalBasicTransformerBlock(nn.Module):
     r"""
@@ -729,6 +793,319 @@ class SkipFFTransformerBlock(nn.Module):
         return hidden_states
 
 
+@maybe_allow_in_graph
+class FreeNoiseTransformerBlock(nn.Module):
+    r"""
+    A FreeNoise Transformer block.
+
+    Parameters:
+        dim (`int`):
+            The number of channels in the input and output.
+        num_attention_heads (`int`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`):
+            The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+        cross_attention_dim (`int`, *optional*):
+            The size of the encoder_hidden_states vector for cross attention.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`):
+            Activation function to be used in feed-forward.
+        num_embeds_ada_norm (`int`, *optional*):
+            The number of diffusion steps used during training. See `Transformer2DModel`.
+        attention_bias (`bool`, defaults to `False`):
+            Configure if the attentions should contain a bias parameter.
+        only_cross_attention (`bool`, defaults to `False`):
+            Whether to use only cross-attention layers. In this case two cross attention layers are used.
+        double_self_attention (`bool`, defaults to `False`):
+            Whether to use two self-attention layers. In this case no cross attention layers are used.
+        upcast_attention (`bool`, defaults to `False`):
+            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+        norm_elementwise_affine (`bool`, defaults to `True`):
+            Whether to use learnable elementwise affine parameters for normalization.
+        norm_type (`str`, defaults to `"layer_norm"`):
+            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+        final_dropout (`bool` defaults to `False`):
+            Whether to apply a final dropout after the last feed-forward layer.
+        attention_type (`str`, defaults to `"default"`):
+            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+        positional_embeddings (`str`, *optional*):
+            The type of positional embeddings to apply to.
+        num_positional_embeddings (`int`, *optional*, defaults to `None`):
+            The maximum number of positional embeddings to apply.
+        ff_inner_dim (`int`, *optional*):
+            Hidden dimension of feed-forward MLP.
+        ff_bias (`bool`, defaults to `True`):
+            Whether or not to use bias in feed-forward MLP.
+        attention_out_bias (`bool`, defaults to `True`):
+            Whether or not to use bias in attention output project layer.
+        context_length (`int`, defaults to `16`):
+            The maximum number of frames that the FreeNoise block processes at once.
+        context_stride (`int`, defaults to `4`):
+            The number of frames to be skipped before starting to process a new batch of `context_length` frames.
+        weighting_scheme (`str`, defaults to `"pyramid"`):
+            The weighting scheme to use for weighting averaging of processed latent frames. As described in the
+            Equation 9. of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting
+            used.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        dropout: float = 0.0,
+        cross_attention_dim: Optional[int] = None,
+        activation_fn: str = "geglu",
+        num_embeds_ada_norm: Optional[int] = None,
+        attention_bias: bool = False,
+        only_cross_attention: bool = False,
+        double_self_attention: bool = False,
+        upcast_attention: bool = False,
+        norm_elementwise_affine: bool = True,
+        norm_type: str = "layer_norm",
+        norm_eps: float = 1e-5,
+        final_dropout: bool = False,
+        positional_embeddings: Optional[str] = None,
+        num_positional_embeddings: Optional[int] = None,
+        ff_inner_dim: Optional[int] = None,
+        ff_bias: bool = True,
+        attention_out_bias: bool = True,
+        context_length: int = 16,
+        context_stride: int = 4,
+        weighting_scheme: str = "pyramid",
+    ):
+        super().__init__()
+        self.dim = dim
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_dim = attention_head_dim
+        self.dropout = dropout
+        self.cross_attention_dim = cross_attention_dim
+        self.activation_fn = activation_fn
+        self.attention_bias = attention_bias
+        self.double_self_attention = double_self_attention
+        self.norm_elementwise_affine = norm_elementwise_affine
+        self.positional_embeddings = positional_embeddings
+        self.num_positional_embeddings = num_positional_embeddings
+        self.only_cross_attention = only_cross_attention
+
+        self.set_free_noise_properties(context_length, context_stride, weighting_scheme)
+
+        # We keep these boolean flags for backward-compatibility.
+        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+        self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+        self.use_layer_norm = norm_type == "layer_norm"
+        self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
+
+        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+            raise ValueError(
+                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+            )
+
+        self.norm_type = norm_type
+        self.num_embeds_ada_norm = num_embeds_ada_norm
+
+        if positional_embeddings and (num_positional_embeddings is None):
+            raise ValueError(
+                "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
+            )
+
+        if positional_embeddings == "sinusoidal":
+            self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+        else:
+            self.pos_embed = None
+
+        # Define 3 blocks. Each block has its own normalization layer.
+        # 1. Self-Attn
+        self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=attention_bias,
+            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+            upcast_attention=upcast_attention,
+            out_bias=attention_out_bias,
+        )
+
+        # 2. Cross-Attn
+        if cross_attention_dim is not None or double_self_attention:
+            self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+            self.attn2 = Attention(
+                query_dim=dim,
+                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+                dropout=dropout,
+                bias=attention_bias,
+                upcast_attention=upcast_attention,
+                out_bias=attention_out_bias,
+            )  # is self-attn if encoder_hidden_states is none
+
+        # 3. Feed-forward
+        self.ff = FeedForward(
+            dim,
+            dropout=dropout,
+            activation_fn=activation_fn,
+            final_dropout=final_dropout,
+            inner_dim=ff_inner_dim,
+            bias=ff_bias,
+        )
+
+        self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+        # let chunk size default to None
+        self._chunk_size = None
+        self._chunk_dim = 0
+
+    def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
+        frame_indices = []
+        for i in range(0, num_frames - self.context_length + 1, self.context_stride):
+            window_start = i
+            window_end = min(num_frames, i + self.context_length)
+            frame_indices.append((window_start, window_end))
+        return frame_indices
+
+    def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
+        if weighting_scheme == "pyramid":
+            if num_frames % 2 == 0:
+                # num_frames = 4 => [1, 2, 2, 1]
+                weights = list(range(1, num_frames // 2 + 1))
+                weights = weights + weights[::-1]
+            else:
+                # num_frames = 5 => [1, 2, 3, 2, 1]
+                weights = list(range(1, num_frames // 2 + 1))
+                weights = weights + [num_frames // 2 + 1] + weights[::-1]
+        else:
+            raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
+
+        return weights
+
+    def set_free_noise_properties(
+        self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
+    ) -> None:
+        self.context_length = context_length
+        self.context_stride = context_stride
+        self.weighting_scheme = weighting_scheme
+
+    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
+        # Sets chunk feed-forward
+        self._chunk_size = chunk_size
+        self._chunk_dim = dim
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        cross_attention_kwargs: Dict[str, Any] = None,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+
+        # hidden_states: [B x H x W, F, C]
+        device = hidden_states.device
+        dtype = hidden_states.dtype
+
+        num_frames = hidden_states.size(1)
+        frame_indices = self._get_frame_indices(num_frames)
+        frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
+        frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
+        is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
+
+        # Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length
+        # For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges:
+        # [(0, 16), (4, 20), (8, 24), (10, 26)]
+        if not is_last_frame_batch_complete:
+            if num_frames < self.context_length:
+                raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}")
+            last_frame_batch_length = num_frames - frame_indices[-1][1]
+            frame_indices.append((num_frames - self.context_length, num_frames))
+
+        num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
+        accumulated_values = torch.zeros_like(hidden_states)
+
+        for i, (frame_start, frame_end) in enumerate(frame_indices):
+            # The reason for slicing here is to ensure that if (frame_end - frame_start) is to handle
+            # cases like frame_indices=[(0, 16), (16, 20)], if the user provided a video with 19 frames, or
+            # essentially a non-multiple of `context_length`.
+            weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
+            weights *= frame_weights
+
+            hidden_states_chunk = hidden_states[:, frame_start:frame_end]
+
+            # Notice that normalization is always applied before the real computation in the following blocks.
+            # 1. Self-Attention
+            norm_hidden_states = self.norm1(hidden_states_chunk)
+
+            if self.pos_embed is not None:
+                norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+            attn_output = self.attn1(
+                norm_hidden_states,
+                encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+                attention_mask=attention_mask,
+                **cross_attention_kwargs,
+            )
+
+            hidden_states_chunk = attn_output + hidden_states_chunk
+            if hidden_states_chunk.ndim == 4:
+                hidden_states_chunk = hidden_states_chunk.squeeze(1)
+
+            # 2. Cross-Attention
+            if self.attn2 is not None:
+                norm_hidden_states = self.norm2(hidden_states_chunk)
+
+                if self.pos_embed is not None and self.norm_type != "ada_norm_single":
+                    norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+                attn_output = self.attn2(
+                    norm_hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=encoder_attention_mask,
+                    **cross_attention_kwargs,
+                )
+                hidden_states_chunk = attn_output + hidden_states_chunk
+
+            if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
+                accumulated_values[:, -last_frame_batch_length:] += (
+                    hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
+                )
+                num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length]
+            else:
+                accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
+                num_times_accumulated[:, frame_start:frame_end] += weights
+
+        hidden_states = torch.where(
+            num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
+        ).to(dtype)
+
+        # 3. Feed-forward
+        norm_hidden_states = self.norm3(hidden_states)
+
+        if self._chunk_size is not None:
+            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+        else:
+            ff_output = self.ff(norm_hidden_states)
+
+        hidden_states = ff_output + hidden_states
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)
+
+        return hidden_states
+
+
 class FeedForward(nn.Module):
     r"""
     A feed-forward layer.
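The FreeNoise block splits the frame axis into overlapping windows of `context_length` frames spaced `context_stride` apart, runs attention per window, and blends overlapping frames with pyramid weights (overlaps are averaged via the `accumulated_values` / `num_times_accumulated` division). A standalone sketch of just the windowing and weighting logic, mirroring `_get_frame_indices` and `_get_frame_weights` above:

```python
# Standalone sketch mirroring the FreeNoise windowing helpers above.
def get_frame_indices(num_frames, context_length=16, context_stride=4):
    indices = [
        (i, min(num_frames, i + context_length))
        for i in range(0, num_frames - context_length + 1, context_stride)
    ]
    if indices[-1][1] != num_frames:  # final window is re-anchored to end on the last frame
        indices.append((num_frames - context_length, num_frames))
    return indices

def pyramid_weights(n):
    half = list(range(1, n // 2 + 1))
    return half + half[::-1] if n % 2 == 0 else half + [n // 2 + 1] + half[::-1]

print(get_frame_indices(25))  # [(0, 16), (4, 20), (8, 24), (9, 25)]
print(pyramid_weights(16))    # [1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1]
```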
@@ -767,6 +1144,8 @@ class FeedForward(nn.Module):
             act_fn = GEGLU(dim, inner_dim, bias=bias)
         elif activation_fn == "geglu-approximate":
             act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
+        elif activation_fn == "swiglu":
+            act_fn = SwiGLU(dim, inner_dim, bias=bias)
 
         self.net = nn.ModuleList([])
         # project in
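With this final hunk, `"swiglu"` becomes a valid `activation_fn` for `FeedForward`. A minimal sketch, assuming `FeedForward`'s existing defaults for the remaining arguments:

```python
# Sketch: constructing FeedForward with the newly supported "swiglu" activation.
import torch
from diffusers.models.attention import FeedForward

ff = FeedForward(dim=320, activation_fn="swiglu")  # inner width defaults to mult * dim
x = torch.randn(1, 77, 320)
print(ff(x).shape)  # torch.Size([1, 77, 320])
```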