diffusers-0.31.0-py3-none-any.whl → diffusers-0.32.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (214)
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +72 -26
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/METADATA +6 -6
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/top_level.txt +0 -0
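The largest additions in 0.32.x are new video and image pipelines (Allegro, HunyuanVideo, LTX-Video, Mochi, Sana), ControlNet-Union and Flux Control/Fill/Redux pipelines, and GGUF/TorchAO quantization backends. As a rough sketch of what the new modules enable, the snippet below loads the LTX text-to-video pipeline; the class name, checkpoint id, and call arguments are assumptions inferred from the file names above rather than part of this diff, so check the 0.32.1 documentation before relying on them.

# Hypothetical usage sketch (not part of the diff): one of the newly added video pipelines.
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

# Checkpoint id and dtype are assumptions; adjust to the released LTX-Video weights.
pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Generate a short clip; argument names follow the other diffusers video pipelines.
frames = pipe(
    prompt="A woman walks along a beach at sunset",
    num_frames=65,
    num_inference_steps=40,
).frames[0]
export_to_video(frames, "output.mp4", fps=24)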
diffusers/models/transformers/transformer_ltx.py (new file)
@@ -0,0 +1,469 @@
+# Copyright 2024 The Genmo team and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import FeedForward
+from ..attention_processor import Attention
+from ..embeddings import PixArtAlphaTextProjection
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormSingle, RMSNorm
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class LTXVideoAttentionProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)
+
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.to(query.dtype)
+
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
+
+
+class LTXVideoRotaryPosEmbed(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        base_num_frames: int = 20,
+        base_height: int = 2048,
+        base_width: int = 2048,
+        patch_size: int = 1,
+        patch_size_t: int = 1,
+        theta: float = 10000.0,
+    ) -> None:
+        super().__init__()
+
+        self.dim = dim
+        self.base_num_frames = base_num_frames
+        self.base_height = base_height
+        self.base_width = base_width
+        self.patch_size = patch_size
+        self.patch_size_t = patch_size_t
+        self.theta = theta
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        num_frames: int,
+        height: int,
+        width: int,
+        rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size = hidden_states.size(0)
+
+        # Always compute rope in fp32
+        grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device)
+        grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device)
+        grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device)
+        grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
+        grid = torch.stack(grid, dim=0)
+        grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
+
+        if rope_interpolation_scale is not None:
+            grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames
+            grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height
+            grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width
+
+        grid = grid.flatten(2, 4).transpose(1, 2)
+
+        start = 1.0
+        end = self.theta
+        freqs = self.theta ** torch.linspace(
+            math.log(start, self.theta),
+            math.log(end, self.theta),
+            self.dim // 6,
+            device=hidden_states.device,
+            dtype=torch.float32,
+        )
+        freqs = freqs * math.pi / 2.0
+        freqs = freqs * (grid.unsqueeze(-1) * 2 - 1)
+        freqs = freqs.transpose(-1, -2).flatten(2)
+
+        cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
+        sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)
+
+        if self.dim % 6 != 0:
+            cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % 6])
+            sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % 6])
+            cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
+            sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
+
+        return cos_freqs, sin_freqs
+
+
+@maybe_allow_in_graph
+class LTXVideoTransformerBlock(nn.Module):
+    r"""
+    Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
+
+    Args:
+        dim (`int`):
+            The number of channels in the input and output.
+        num_attention_heads (`int`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`):
+            The number of channels in each head.
+        qk_norm (`str`, defaults to `"rms_norm_across_heads"`):
+            The normalization layer to use.
+        activation_fn (`str`, defaults to `"gelu-approximate"`):
+            Activation function to use in feed-forward.
+        eps (`float`, defaults to `1e-6`):
+            Epsilon value for normalization layers.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        cross_attention_dim: int,
+        qk_norm: str = "rms_norm_across_heads",
+        activation_fn: str = "gelu-approximate",
+        attention_bias: bool = True,
+        attention_out_bias: bool = True,
+        eps: float = 1e-6,
+        elementwise_affine: bool = False,
+    ):
+        super().__init__()
+
+        self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            kv_heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            bias=attention_bias,
+            cross_attention_dim=None,
+            out_bias=attention_out_bias,
+            qk_norm=qk_norm,
+            processor=LTXVideoAttentionProcessor2_0(),
+        )
+
+        self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+        self.attn2 = Attention(
+            query_dim=dim,
+            cross_attention_dim=cross_attention_dim,
+            heads=num_attention_heads,
+            kv_heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            bias=attention_bias,
+            out_bias=attention_out_bias,
+            qk_norm=qk_norm,
+            processor=LTXVideoAttentionProcessor2_0(),
+        )
+
+        self.ff = FeedForward(dim, activation_fn=activation_fn)
+
+        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        batch_size = hidden_states.size(0)
+        norm_hidden_states = self.norm1(hidden_states)
+
+        num_ada_params = self.scale_shift_table.shape[0]
+        ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
+        norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+
+        attn_hidden_states = self.attn1(
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=None,
+            image_rotary_emb=image_rotary_emb,
+        )
+        hidden_states = hidden_states + attn_hidden_states * gate_msa
+
+        attn_hidden_states = self.attn2(
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            image_rotary_emb=None,
+            attention_mask=encoder_attention_mask,
+        )
+        hidden_states = hidden_states + attn_hidden_states
+        norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp
+
+        ff_output = self.ff(norm_hidden_states)
+        hidden_states = hidden_states + ff_output * gate_mlp
+
+        return hidden_states
+
+
+@maybe_allow_in_graph
+class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
+    r"""
+    A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
+
+    Args:
+        in_channels (`int`, defaults to `128`):
+            The number of channels in the input.
+        out_channels (`int`, defaults to `128`):
+            The number of channels in the output.
+        patch_size (`int`, defaults to `1`):
+            The size of the spatial patches to use in the patch embedding layer.
+        patch_size_t (`int`, defaults to `1`):
+            The size of the temporal patches to use in the patch embedding layer.
+        num_attention_heads (`int`, defaults to `32`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, defaults to `64`):
+            The number of channels in each head.
+        cross_attention_dim (`int`, defaults to `2048`):
+            The number of channels for cross attention heads.
+        num_layers (`int`, defaults to `28`):
+            The number of layers of Transformer blocks to use.
+        activation_fn (`str`, defaults to `"gelu-approximate"`):
+            Activation function to use in feed-forward.
+        qk_norm (`str`, defaults to `"rms_norm_across_heads"`):
+            The normalization layer to use.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 128,
+        out_channels: int = 128,
+        patch_size: int = 1,
+        patch_size_t: int = 1,
+        num_attention_heads: int = 32,
+        attention_head_dim: int = 64,
+        cross_attention_dim: int = 2048,
+        num_layers: int = 28,
+        activation_fn: str = "gelu-approximate",
+        qk_norm: str = "rms_norm_across_heads",
+        norm_elementwise_affine: bool = False,
+        norm_eps: float = 1e-6,
+        caption_channels: int = 4096,
+        attention_bias: bool = True,
+        attention_out_bias: bool = True,
+    ) -> None:
+        super().__init__()
+
+        out_channels = out_channels or in_channels
+        inner_dim = num_attention_heads * attention_head_dim
+
+        self.proj_in = nn.Linear(in_channels, inner_dim)
+
+        self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
+        self.time_embed = AdaLayerNormSingle(inner_dim, use_additional_conditions=False)
+
+        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
+
+        self.rope = LTXVideoRotaryPosEmbed(
+            dim=inner_dim,
+            base_num_frames=20,
+            base_height=2048,
+            base_width=2048,
+            patch_size=patch_size,
+            patch_size_t=patch_size_t,
+            theta=10000.0,
+        )
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                LTXVideoTransformerBlock(
+                    dim=inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    cross_attention_dim=cross_attention_dim,
+                    qk_norm=qk_norm,
+                    activation_fn=activation_fn,
+                    attention_bias=attention_bias,
+                    attention_out_bias=attention_out_bias,
+                    eps=norm_eps,
+                    elementwise_affine=norm_elementwise_affine,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
+        self.proj_out = nn.Linear(inner_dim, out_channels)
+
+        self.gradient_checkpointing = False
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        encoder_attention_mask: torch.Tensor,
+        num_frames: int,
+        height: int,
+        width: int,
+        rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ) -> torch.Tensor:
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
+        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+        batch_size = hidden_states.size(0)
+        hidden_states = self.proj_in(hidden_states)
+
+        temb, embedded_timestep = self.time_embed(
+            timestep.flatten(),
+            batch_size=batch_size,
+            hidden_dtype=hidden_states.dtype,
+        )
+
+        temb = temb.view(batch_size, -1, temb.size(-1))
+        embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
+
+        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
+
+        for block in self.transformer_blocks:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    encoder_attention_mask,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    image_rotary_emb=image_rotary_emb,
+                    encoder_attention_mask=encoder_attention_mask,
+                )
+
+        scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
+        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
+
+        hidden_states = self.norm_out(hidden_states)
+        hidden_states = hidden_states * (1 + scale) + shift
+        output = self.proj_out(hidden_states)
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (output,)
+        return Transformer2DModelOutput(sample=output)
+
+
+def apply_rotary_emb(x, freqs):
+    cos, sin = freqs
+    x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1)  # [B, S, H, D // 2]
+    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
+    out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+    return out
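The new file defines LTXVideoTransformer3DModel: packed video latents are projected to the inner dimension, modulated by a per-timestep scale/shift table (AdaLayerNormSingle), run through self-attention with 3D rotary position embeddings plus cross-attention against projected caption embeddings, and finally unprojected to the output channels. A minimal shape-check sketch, using a deliberately tiny configuration rather than the release defaults, might look like the following; the import path is taken from the file listing above, everything else is illustrative.

# Shape-check sketch for the new model (tiny config, not the release defaults).
import torch
from diffusers.models.transformers.transformer_ltx import LTXVideoTransformer3DModel

model = LTXVideoTransformer3DModel(
    in_channels=8,
    out_channels=8,
    num_attention_heads=2,
    attention_head_dim=8,
    cross_attention_dim=16,
    num_layers=1,
    caption_channels=16,
)

batch, frames, height, width = 1, 2, 4, 4
hidden_states = torch.randn(batch, frames * height * width, 8)  # packed video latents, one token per patch
encoder_hidden_states = torch.randn(batch, 12, 16)              # caption embeddings (caption_channels=16)
encoder_attention_mask = torch.ones(batch, 12)
timestep = torch.randint(0, 1000, (batch,))

out = model(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    timestep=timestep,
    encoder_attention_mask=encoder_attention_mask,
    num_frames=frames,
    height=height,
    width=width,
).sample
print(out.shape)  # torch.Size([1, 32, 8]): same token count, out_channels last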