diffusers 0.29.2__py3-none-any.whl → 0.30.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_pipeline.py +2252 -0
  10. diffusers/loaders/peft.py +213 -5
  11. diffusers/loaders/single_file.py +3 -14
  12. diffusers/loaders/single_file_model.py +31 -10
  13. diffusers/loaders/single_file_utils.py +293 -8
  14. diffusers/loaders/textual_inversion.py +1 -6
  15. diffusers/loaders/unet.py +23 -208
  16. diffusers/models/__init__.py +20 -0
  17. diffusers/models/activations.py +22 -0
  18. diffusers/models/attention.py +386 -7
  19. diffusers/models/attention_processor.py +1937 -629
  20. diffusers/models/autoencoders/__init__.py +2 -0
  21. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  22. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1271 -0
  23. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  24. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  25. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  26. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  27. diffusers/models/autoencoders/vq_model.py +4 -4
  28. diffusers/models/controlnet.py +2 -3
  29. diffusers/models/controlnet_hunyuan.py +401 -0
  30. diffusers/models/controlnet_sd3.py +11 -11
  31. diffusers/models/controlnet_sparsectrl.py +789 -0
  32. diffusers/models/controlnet_xs.py +40 -10
  33. diffusers/models/downsampling.py +68 -0
  34. diffusers/models/embeddings.py +403 -36
  35. diffusers/models/model_loading_utils.py +1 -3
  36. diffusers/models/modeling_flax_utils.py +1 -6
  37. diffusers/models/modeling_utils.py +4 -16
  38. diffusers/models/normalization.py +203 -12
  39. diffusers/models/transformers/__init__.py +6 -0
  40. diffusers/models/transformers/auraflow_transformer_2d.py +543 -0
  41. diffusers/models/transformers/cogvideox_transformer_3d.py +485 -0
  42. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  43. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  44. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  45. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  46. diffusers/models/transformers/prior_transformer.py +1 -1
  47. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  48. diffusers/models/transformers/transformer_flux.py +455 -0
  49. diffusers/models/transformers/transformer_sd3.py +18 -4
  50. diffusers/models/unets/unet_1d_blocks.py +1 -1
  51. diffusers/models/unets/unet_2d_condition.py +8 -1
  52. diffusers/models/unets/unet_3d_blocks.py +51 -920
  53. diffusers/models/unets/unet_3d_condition.py +4 -1
  54. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  55. diffusers/models/unets/unet_kandinsky3.py +1 -1
  56. diffusers/models/unets/unet_motion_model.py +1330 -84
  57. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  58. diffusers/models/unets/unet_stable_cascade.py +1 -3
  59. diffusers/models/unets/uvit_2d.py +1 -1
  60. diffusers/models/upsampling.py +64 -0
  61. diffusers/models/vq_model.py +8 -4
  62. diffusers/optimization.py +1 -1
  63. diffusers/pipelines/__init__.py +100 -3
  64. diffusers/pipelines/animatediff/__init__.py +4 -0
  65. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  66. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  70. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  71. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  72. diffusers/pipelines/aura_flow/__init__.py +48 -0
  73. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  74. diffusers/pipelines/auto_pipeline.py +97 -19
  75. diffusers/pipelines/cogvideo/__init__.py +48 -0
  76. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +746 -0
  77. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  78. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  79. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  80. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  81. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  82. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  83. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  84. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  85. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  86. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  87. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  88. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  90. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  91. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  96. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  97. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  98. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  103. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  104. diffusers/pipelines/flux/__init__.py +47 -0
  105. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  106. diffusers/pipelines/flux/pipeline_output.py +21 -0
  107. diffusers/pipelines/free_init_utils.py +2 -0
  108. diffusers/pipelines/free_noise_utils.py +236 -0
  109. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  110. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  111. diffusers/pipelines/kolors/__init__.py +54 -0
  112. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  113. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  114. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  115. diffusers/pipelines/kolors/text_encoder.py +889 -0
  116. diffusers/pipelines/kolors/tokenizer.py +334 -0
  117. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  118. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  119. diffusers/pipelines/latte/__init__.py +48 -0
  120. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  121. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  122. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  123. diffusers/pipelines/lumina/__init__.py +48 -0
  124. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  125. diffusers/pipelines/pag/__init__.py +67 -0
  126. diffusers/pipelines/pag/pag_utils.py +237 -0
  127. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  128. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  129. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  130. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  131. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  132. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  133. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  138. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  139. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  140. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  141. diffusers/pipelines/pipeline_utils.py +2 -14
  142. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  143. diffusers/pipelines/stable_audio/__init__.py +50 -0
  144. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  145. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  146. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  147. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  156. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  157. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  158. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  160. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  161. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  162. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  163. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  164. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  165. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  166. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  167. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  168. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  175. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  179. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  180. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  181. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  182. diffusers/schedulers/__init__.py +8 -0
  183. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  184. diffusers/schedulers/scheduling_ddim.py +1 -1
  185. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  186. diffusers/schedulers/scheduling_ddpm.py +1 -1
  187. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  188. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  189. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  190. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  191. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  192. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  193. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  194. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  195. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  196. diffusers/schedulers/scheduling_ipndm.py +1 -1
  197. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  198. diffusers/schedulers/scheduling_utils.py +1 -3
  199. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  200. diffusers/training_utils.py +99 -14
  201. diffusers/utils/__init__.py +2 -2
  202. diffusers/utils/dummy_pt_objects.py +210 -0
  203. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  204. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  205. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  206. diffusers/utils/dynamic_modules_utils.py +1 -11
  207. diffusers/utils/export_utils.py +50 -6
  208. diffusers/utils/hub_utils.py +45 -42
  209. diffusers/utils/import_utils.py +37 -15
  210. diffusers/utils/loading_utils.py +80 -3
  211. diffusers/utils/testing_utils.py +11 -8
  212. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/METADATA +73 -83
  213. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/RECORD +217 -164
  214. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/WHEEL +1 -1
  215. diffusers/loaders/autoencoder.py +0 -146
  216. diffusers/loaders/controlnet.py +0 -136
  217. diffusers/loaders/lora.py +0 -1728
  218. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/LICENSE +0 -0
  219. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/entry_points.txt +0 -0
  220. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/top_level.txt +0 -0
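As a quick orientation to the file list above, the snippet below is an illustrative smoke test (editorial, not part of the diff): it imports a few of the pipeline classes added by the new modules listed above, assuming a standard install of the 0.30.1 wheel.

    import diffusers
    from diffusers import CogVideoXPipeline, FluxPipeline, LattePipeline  # pipelines added in 0.30.x

    # Confirm the installed version matches the right-hand side of this diff.
    print(diffusers.__version__)  # expected: 0.30.1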
diffusers/models/transformers/latte_transformer_3d.py (added)
@@ -0,0 +1,327 @@
+ # Copyright 2024 the Latte Team and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from typing import Optional
+
+ import torch
+ from torch import nn
+
+ from ...configuration_utils import ConfigMixin, register_to_config
+ from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
+ from ..attention import BasicTransformerBlock
+ from ..embeddings import PatchEmbed
+ from ..modeling_outputs import Transformer2DModelOutput
+ from ..modeling_utils import ModelMixin
+ from ..normalization import AdaLayerNormSingle
+
+
+ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
+     _supports_gradient_checkpointing = True
+
+     """
+     A 3D Transformer model for video-like data, paper: https://arxiv.org/abs/2401.03048, official code:
+     https://github.com/Vchitect/Latte
+
+     Parameters:
+         num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+         attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+         in_channels (`int`, *optional*):
+             The number of channels in the input.
+         out_channels (`int`, *optional*):
+             The number of channels in the output.
+         num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+         cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+         attention_bias (`bool`, *optional*):
+             Configure if the `TransformerBlocks` attention should contain a bias parameter.
+         sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+             This is fixed during training since it is used to learn a number of position embeddings.
+         patch_size (`int`, *optional*):
+             The size of the patches to use in the patch embedding layer.
+         activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
+         num_embeds_ada_norm (`int`, *optional*):
+             The number of diffusion steps used during training. Pass if at least one of the norm_layers is
+             `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
+             added to the hidden states. During inference, you can denoise for up to, but not more than,
+             `num_embeds_ada_norm` steps.
+         norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+             The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`.
+         norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+             Whether or not to use elementwise affine in normalization layers.
+         norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers.
+         caption_channels (`int`, *optional*):
+             The number of channels in the caption embeddings.
+         video_length (`int`, *optional*):
+             The number of frames in the video-like data.
+     """
+
+     @register_to_config
+     def __init__(
+         self,
+         num_attention_heads: int = 16,
+         attention_head_dim: int = 88,
+         in_channels: Optional[int] = None,
+         out_channels: Optional[int] = None,
+         num_layers: int = 1,
+         dropout: float = 0.0,
+         cross_attention_dim: Optional[int] = None,
+         attention_bias: bool = False,
+         sample_size: int = 64,
+         patch_size: Optional[int] = None,
+         activation_fn: str = "geglu",
+         num_embeds_ada_norm: Optional[int] = None,
+         norm_type: str = "layer_norm",
+         norm_elementwise_affine: bool = True,
+         norm_eps: float = 1e-5,
+         caption_channels: int = None,
+         video_length: int = 16,
+     ):
+         super().__init__()
+         inner_dim = num_attention_heads * attention_head_dim
+
+         # 1. Define input layers
+         self.height = sample_size
+         self.width = sample_size
+
+         interpolation_scale = self.config.sample_size // 64
+         interpolation_scale = max(interpolation_scale, 1)
+         self.pos_embed = PatchEmbed(
+             height=sample_size,
+             width=sample_size,
+             patch_size=patch_size,
+             in_channels=in_channels,
+             embed_dim=inner_dim,
+             interpolation_scale=interpolation_scale,
+         )
+
+         # 2. Define spatial transformers blocks
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 BasicTransformerBlock(
+                     inner_dim,
+                     num_attention_heads,
+                     attention_head_dim,
+                     dropout=dropout,
+                     cross_attention_dim=cross_attention_dim,
+                     activation_fn=activation_fn,
+                     num_embeds_ada_norm=num_embeds_ada_norm,
+                     attention_bias=attention_bias,
+                     norm_type=norm_type,
+                     norm_elementwise_affine=norm_elementwise_affine,
+                     norm_eps=norm_eps,
+                 )
+                 for d in range(num_layers)
+             ]
+         )
+
+         # 3. Define temporal transformers blocks
+         self.temporal_transformer_blocks = nn.ModuleList(
+             [
+                 BasicTransformerBlock(
+                     inner_dim,
+                     num_attention_heads,
+                     attention_head_dim,
+                     dropout=dropout,
+                     cross_attention_dim=None,
+                     activation_fn=activation_fn,
+                     num_embeds_ada_norm=num_embeds_ada_norm,
+                     attention_bias=attention_bias,
+                     norm_type=norm_type,
+                     norm_elementwise_affine=norm_elementwise_affine,
+                     norm_eps=norm_eps,
+                 )
+                 for d in range(num_layers)
+             ]
+         )
+
+         # 4. Define output layers
+         self.out_channels = in_channels if out_channels is None else out_channels
+         self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+         self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
+         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
+
+         # 5. Latte other blocks.
+         self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=False)
+         self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
+
+         # define temporal positional embedding
+         temp_pos_embed = get_1d_sincos_pos_embed_from_grid(
+             inner_dim, torch.arange(0, video_length).unsqueeze(1)
+         )  # 1152 hidden size
+         self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False)
+
+         self.gradient_checkpointing = False
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         self.gradient_checkpointing = value
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         timestep: Optional[torch.LongTensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.Tensor] = None,
+         enable_temporal_attentions: bool = True,
+         return_dict: bool = True,
+     ):
+         """
+         The [`LatteTransformer3DModel`] forward method.
+
+         Args:
+             hidden_states of shape `(batch size, channel, num_frame, height, width)`:
+                 Input `hidden_states`.
+             timestep (`torch.LongTensor`, *optional*):
+                 Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+             encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+                 self-attention.
+             encoder_attention_mask (`torch.Tensor`, *optional*):
+                 Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
+
+                 * Mask `(batch, sequence_length)` True = keep, False = discard.
+                 * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
+
+                 If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
+                 above. This bias will be added to the cross-attention scores.
+             enable_temporal_attentions (`bool`, *optional*, defaults to `True`):
+                 Whether to enable temporal attentions.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~models.modeling_outputs.Transformer2DModelOutput`] instead of a plain
+                 tuple.
+
+         Returns:
+             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+             `tuple` where the first element is the sample tensor.
+         """
+
+         # Reshape hidden states
+         batch_size, channels, num_frame, height, width = hidden_states.shape
+         # batch_size channels num_frame height width -> (batch_size * num_frame) channels height width
+         hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(-1, channels, height, width)
+
+         # Input
+         height, width = (
+             hidden_states.shape[-2] // self.config.patch_size,
+             hidden_states.shape[-1] // self.config.patch_size,
+         )
+         num_patches = height * width
+
+         hidden_states = self.pos_embed(hidden_states)  # already adds positional embeddings
+
+         added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
+         timestep, embedded_timestep = self.adaln_single(
+             timestep, added_cond_kwargs=added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+         )
+
+         # Prepare text embeddings for spatial block
+         # batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size
+         encoder_hidden_states = self.caption_projection(encoder_hidden_states)  # 3 120 1152
+         encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view(
+             -1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]
+         )
+
+         # Prepare timesteps for spatial and temporal block
+         timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1])
+         timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1])
+
+         # Spatial and temporal transformer blocks
+         for i, (spatial_block, temp_block) in enumerate(
+             zip(self.transformer_blocks, self.temporal_transformer_blocks)
+         ):
+             if self.training and self.gradient_checkpointing:
+                 hidden_states = torch.utils.checkpoint.checkpoint(
+                     spatial_block,
+                     hidden_states,
+                     None,  # attention_mask
+                     encoder_hidden_states_spatial,
+                     encoder_attention_mask,
+                     timestep_spatial,
+                     None,  # cross_attention_kwargs
+                     None,  # class_labels
+                     use_reentrant=False,
+                 )
+             else:
+                 hidden_states = spatial_block(
+                     hidden_states,
+                     None,  # attention_mask
+                     encoder_hidden_states_spatial,
+                     encoder_attention_mask,
+                     timestep_spatial,
+                     None,  # cross_attention_kwargs
+                     None,  # class_labels
+                 )
+
+             if enable_temporal_attentions:
+                 # (batch_size * num_frame) num_tokens hidden_size -> (batch_size * num_tokens) num_frame hidden_size
+                 hidden_states = hidden_states.reshape(
+                     batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1]
+                 ).permute(0, 2, 1, 3)
+                 hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
+
+                 if i == 0 and num_frame > 1:
+                     hidden_states = hidden_states + self.temp_pos_embed
+
+                 if self.training and self.gradient_checkpointing:
+                     hidden_states = torch.utils.checkpoint.checkpoint(
+                         temp_block,
+                         hidden_states,
+                         None,  # attention_mask
+                         None,  # encoder_hidden_states
+                         None,  # encoder_attention_mask
+                         timestep_temp,
+                         None,  # cross_attention_kwargs
+                         None,  # class_labels
+                         use_reentrant=False,
+                     )
+                 else:
+                     hidden_states = temp_block(
+                         hidden_states,
+                         None,  # attention_mask
+                         None,  # encoder_hidden_states
+                         None,  # encoder_attention_mask
+                         timestep_temp,
+                         None,  # cross_attention_kwargs
+                         None,  # class_labels
+                     )
+
+                 # (batch_size * num_tokens) num_frame hidden_size -> (batch_size * num_frame) num_tokens hidden_size
+                 hidden_states = hidden_states.reshape(
+                     batch_size, -1, hidden_states.shape[-2], hidden_states.shape[-1]
+                 ).permute(0, 2, 1, 3)
+                 hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
+
+         embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
+         shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+         hidden_states = self.norm_out(hidden_states)
+         # Modulation
+         hidden_states = hidden_states * (1 + scale) + shift
+         hidden_states = self.proj_out(hidden_states)
+
+         # unpatchify
+         if self.adaln_single is None:
+             height = width = int(hidden_states.shape[1] ** 0.5)
+         hidden_states = hidden_states.reshape(
+             shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels)
+         )
+         hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+         output = hidden_states.reshape(
+             shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size)
+         )
+         output = output.reshape(batch_size, -1, output.shape[-3], output.shape[-2], output.shape[-1]).permute(
+             0, 2, 1, 3, 4
+         )
+
+         if not return_dict:
+             return (output,)
+
+         return Transformer2DModelOutput(sample=output)
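For reference, here is a minimal sketch (editorial, not part of the package) that instantiates the LatteTransformer3DModel added above. The configuration values are deliberately tiny placeholders chosen only to satisfy the constructor; they are not the configuration of any released Latte checkpoint.

    import torch
    from diffusers.models.transformers.latte_transformer_3d import LatteTransformer3DModel

    # in_channels, patch_size, cross_attention_dim and caption_channels have no usable
    # defaults, so a concrete (here: toy-sized) value must be supplied for each.
    model = LatteTransformer3DModel(
        num_attention_heads=2,
        attention_head_dim=8,  # inner_dim = 2 * 8 = 16
        in_channels=4,
        num_layers=1,
        cross_attention_dim=16,
        sample_size=8,
        patch_size=2,
        norm_type="ada_norm_single",  # timestep-modulated (PixArt-style) block variant, illustrative choice
        caption_channels=16,
        video_length=4,
    )
    print(sum(p.numel() for p in model.parameters()))  # parameter count of the toy config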
diffusers/models/transformers/lumina_nextdit2d.py (added)
@@ -0,0 +1,340 @@
+ # Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Dict, Optional
+
+ import torch
+ import torch.nn as nn
+
+ from ...configuration_utils import ConfigMixin, register_to_config
+ from ...utils import logging
+ from ..attention import LuminaFeedForward
+ from ..attention_processor import Attention, LuminaAttnProcessor2_0
+ from ..embeddings import (
+     LuminaCombinedTimestepCaptionEmbedding,
+     LuminaPatchEmbed,
+ )
+ from ..modeling_outputs import Transformer2DModelOutput
+ from ..modeling_utils import ModelMixin
+ from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ class LuminaNextDiTBlock(nn.Module):
+     """
+     A LuminaNextDiTBlock for LuminaNextDiT2DModel.
+
+     Parameters:
+         dim (`int`): Embedding dimension of the input features.
+         num_attention_heads (`int`): Number of attention heads.
+         num_kv_heads (`int`):
+             Number of attention heads in key and value features (if using GQA), or set to None for the same as query.
+         multiple_of (`int`): The value the feed-forward hidden dimension is rounded up to a multiple of.
+         ffn_dim_multiplier (`float`): The multiplier factor for the feed-forward layer dimension.
+         norm_eps (`float`): The epsilon value for normalization layers.
+         qk_norm (`bool`): Whether to normalize the query and key.
+         cross_attention_dim (`int`): Cross attention embedding dimension of the input text prompt hidden_states.
+         norm_elementwise_affine (`bool`, *optional*, defaults to `True`): Whether normalization layers use learnable elementwise affine parameters.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         num_attention_heads: int,
+         num_kv_heads: int,
+         multiple_of: int,
+         ffn_dim_multiplier: float,
+         norm_eps: float,
+         qk_norm: bool,
+         cross_attention_dim: int,
+         norm_elementwise_affine: bool = True,
+     ) -> None:
+         super().__init__()
+         self.head_dim = dim // num_attention_heads
+
+         self.gate = nn.Parameter(torch.zeros([num_attention_heads]))
+
+         # Self-attention
+         self.attn1 = Attention(
+             query_dim=dim,
+             cross_attention_dim=None,
+             dim_head=dim // num_attention_heads,
+             qk_norm="layer_norm_across_heads" if qk_norm else None,
+             heads=num_attention_heads,
+             kv_heads=num_kv_heads,
+             eps=1e-5,
+             bias=False,
+             out_bias=False,
+             processor=LuminaAttnProcessor2_0(),
+         )
+         self.attn1.to_out = nn.Identity()
+
+         # Cross-attention
+         self.attn2 = Attention(
+             query_dim=dim,
+             cross_attention_dim=cross_attention_dim,
+             dim_head=dim // num_attention_heads,
+             qk_norm="layer_norm_across_heads" if qk_norm else None,
+             heads=num_attention_heads,
+             kv_heads=num_kv_heads,
+             eps=1e-5,
+             bias=False,
+             out_bias=False,
+             processor=LuminaAttnProcessor2_0(),
+         )
+
+         self.feed_forward = LuminaFeedForward(
+             dim=dim,
+             inner_dim=4 * dim,
+             multiple_of=multiple_of,
+             ffn_dim_multiplier=ffn_dim_multiplier,
+         )
+
+         self.norm1 = LuminaRMSNormZero(
+             embedding_dim=dim,
+             norm_eps=norm_eps,
+             norm_elementwise_affine=norm_elementwise_affine,
+         )
+         self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+
+         self.norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+         self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+
+         self.norm1_context = RMSNorm(cross_attention_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         image_rotary_emb: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         encoder_mask: torch.Tensor,
+         temb: torch.Tensor,
+         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+     ):
+         """
+         Perform a forward pass through the LuminaNextDiTBlock.
+
+         Parameters:
+             hidden_states (`torch.Tensor`): The input hidden_states for the LuminaNextDiTBlock.
+             attention_mask (`torch.Tensor`): The attention mask corresponding to the input hidden_states.
+             image_rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies.
+             encoder_hidden_states (`torch.Tensor`): The text prompt hidden_states produced by the Gemma encoder.
+             encoder_mask (`torch.Tensor`): The attention mask for the text prompt hidden_states.
+             temb (`torch.Tensor`): Timestep embedding with text prompt embedding.
+             cross_attention_kwargs (`Dict[str, Any]`): Additional kwargs for cross attention.
+         """
+         residual = hidden_states
+
+         # Self-attention
+         norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
+         self_attn_output = self.attn1(
+             hidden_states=norm_hidden_states,
+             encoder_hidden_states=norm_hidden_states,
+             attention_mask=attention_mask,
+             query_rotary_emb=image_rotary_emb,
+             key_rotary_emb=image_rotary_emb,
+             **cross_attention_kwargs,
+         )
+
+         # Cross-attention
+         norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
+         cross_attn_output = self.attn2(
+             hidden_states=norm_hidden_states,
+             encoder_hidden_states=norm_encoder_hidden_states,
+             attention_mask=encoder_mask,
+             query_rotary_emb=image_rotary_emb,
+             key_rotary_emb=None,
+             **cross_attention_kwargs,
+         )
+         cross_attn_output = cross_attn_output * self.gate.tanh().view(1, 1, -1, 1)
+         mixed_attn_output = self_attn_output + cross_attn_output
+         mixed_attn_output = mixed_attn_output.flatten(-2)
+         # linear proj
+         hidden_states = self.attn2.to_out[0](mixed_attn_output)
+
+         hidden_states = residual + gate_msa.unsqueeze(1).tanh() * self.norm2(hidden_states)
+
+         mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
+
+         hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
+
+         return hidden_states
+
+
+ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
+     """
+     LuminaNextDiT: Diffusion model with a Transformer backbone.
+
+     Inherits ModelMixin and ConfigMixin to be compatible with the diffusers sampler and pipeline machinery.
+
+     Parameters:
+         sample_size (`int`): The width of the latent images. This is fixed during training since
+             it is used to learn a number of position embeddings.
+         patch_size (`int`, *optional*, defaults to 2):
+             The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
+         in_channels (`int`, *optional*, defaults to 4):
+             The number of input channels for the model. Typically, this matches the number of channels in the input
+             images.
+         hidden_size (`int`, *optional*, defaults to 2304):
+             The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
+             hidden representations.
+         num_layers (`int`, *optional*, defaults to 32):
+             The number of layers in the model. This defines the depth of the neural network.
+         num_attention_heads (`int`, *optional*, defaults to 32):
+             The number of attention heads in each attention layer. This parameter specifies how many separate attention
+             mechanisms are used.
+         num_kv_heads (`int`, *optional*, defaults to `None`):
+             The number of key-value heads in the attention mechanism, if different from the number of attention heads.
+             If None, it defaults to num_attention_heads.
+         multiple_of (`int`, *optional*, defaults to 256):
+             A factor that the hidden size should be a multiple of. This can help optimize certain hardware
+             configurations.
+         ffn_dim_multiplier (`float`, *optional*):
+             A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
+             the model configuration.
+         norm_eps (`float`, *optional*, defaults to 1e-5):
+             A small value added to the denominator for numerical stability in normalization layers.
+         learn_sigma (`bool`, *optional*, defaults to True):
+             Whether the model should learn the sigma parameter, which might be related to uncertainty or variance in
+             predictions.
+         qk_norm (`bool`, *optional*, defaults to True):
+             Indicates if the queries and keys in the attention mechanism should be normalized.
+         cross_attention_dim (`int`, *optional*, defaults to 2048):
+             The dimensionality of the text embeddings. This parameter defines the size of the text representations used
+             in the model.
+         scaling_factor (`float`, *optional*, defaults to 1.0):
+             A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
+             overall scale of the model's operations.
+     """
+
+     @register_to_config
+     def __init__(
+         self,
+         sample_size: int = 128,
+         patch_size: Optional[int] = 2,
+         in_channels: Optional[int] = 4,
+         hidden_size: Optional[int] = 2304,
+         num_layers: Optional[int] = 32,
+         num_attention_heads: Optional[int] = 32,
+         num_kv_heads: Optional[int] = None,
+         multiple_of: Optional[int] = 256,
+         ffn_dim_multiplier: Optional[float] = None,
+         norm_eps: Optional[float] = 1e-5,
+         learn_sigma: Optional[bool] = True,
+         qk_norm: Optional[bool] = True,
+         cross_attention_dim: Optional[int] = 2048,
+         scaling_factor: Optional[float] = 1.0,
+     ) -> None:
+         super().__init__()
+         self.sample_size = sample_size
+         self.patch_size = patch_size
+         self.in_channels = in_channels
+         self.out_channels = in_channels * 2 if learn_sigma else in_channels
+         self.hidden_size = hidden_size
+         self.num_attention_heads = num_attention_heads
+         self.head_dim = hidden_size // num_attention_heads
+         self.scaling_factor = scaling_factor
+
+         self.patch_embedder = LuminaPatchEmbed(
+             patch_size=patch_size, in_channels=in_channels, embed_dim=hidden_size, bias=True
+         )
+
+         self.pad_token = nn.Parameter(torch.empty(hidden_size))
+
+         self.time_caption_embed = LuminaCombinedTimestepCaptionEmbedding(
+             hidden_size=min(hidden_size, 1024), cross_attention_dim=cross_attention_dim
+         )
+
+         self.layers = nn.ModuleList(
+             [
+                 LuminaNextDiTBlock(
+                     hidden_size,
+                     num_attention_heads,
+                     num_kv_heads,
+                     multiple_of,
+                     ffn_dim_multiplier,
+                     norm_eps,
+                     qk_norm,
+                     cross_attention_dim,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+         self.norm_out = LuminaLayerNormContinuous(
+             embedding_dim=hidden_size,
+             conditioning_embedding_dim=min(hidden_size, 1024),
+             elementwise_affine=False,
+             eps=1e-6,
+             bias=True,
+             out_dim=patch_size * patch_size * self.out_channels,
+         )
+         # self.final_layer = LuminaFinalLayer(hidden_size, patch_size, self.out_channels)
+
+         assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         timestep: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         encoder_mask: torch.Tensor,
+         image_rotary_emb: torch.Tensor,
+         cross_attention_kwargs: Dict[str, Any] = None,
+         return_dict=True,
+     ) -> torch.Tensor:
+         """
+         Forward pass of LuminaNextDiT.
+
+         Parameters:
+             hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W).
+             timestep (torch.Tensor): Tensor of diffusion timesteps of shape (N,).
+             encoder_hidden_states (torch.Tensor): Tensor of caption features of shape (N, D).
+             encoder_mask (torch.Tensor): Tensor of caption masks of shape (N, L).
+         """
+         hidden_states, mask, img_size, image_rotary_emb = self.patch_embedder(hidden_states, image_rotary_emb)
+         image_rotary_emb = image_rotary_emb.to(hidden_states.device)
+
+         temb = self.time_caption_embed(timestep, encoder_hidden_states, encoder_mask)
+
+         encoder_mask = encoder_mask.bool()
+         for layer in self.layers:
+             hidden_states = layer(
+                 hidden_states,
+                 mask,
+                 image_rotary_emb,
+                 encoder_hidden_states,
+                 encoder_mask,
+                 temb=temb,
+                 cross_attention_kwargs=cross_attention_kwargs,
+             )
+
+         hidden_states = self.norm_out(hidden_states, temb)
+
+         # unpatchify
+         height_tokens = width_tokens = self.patch_size
+         height, width = img_size[0]
+         batch_size = hidden_states.size(0)
+         sequence_length = (height // height_tokens) * (width // width_tokens)
+         hidden_states = hidden_states[:, :sequence_length].view(
+             batch_size, height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels
+         )
+         output = hidden_states.permute(0, 5, 1, 3, 2, 4).flatten(4, 5).flatten(2, 3)
+
+         if not return_dict:
+             return (output,)
+
+         return Transformer2DModelOutput(sample=output)
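Similarly, a minimal sketch (editorial, not part of the package) of constructing the LuminaNextDiT2DModel defined above with a toy configuration; the values below only need to respect the constraint asserted in __init__ (head dimension divisible by 4) and are unrelated to the shipped Lumina-Next checkpoints.

    from diffusers.models.transformers.lumina_nextdit2d import LuminaNextDiT2DModel

    # Toy config: hidden_size / num_attention_heads = 8, which satisfies the 2D RoPE assertion.
    model = LuminaNextDiT2DModel(
        sample_size=16,
        patch_size=2,
        in_channels=4,
        hidden_size=32,
        num_layers=1,
        num_attention_heads=4,
        num_kv_heads=2,
        multiple_of=32,
        cross_attention_dim=32,
    )
    print(sum(p.numel() for p in model.parameters()))  # parameter count of the toy config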