diffusers 0.34.0__py3-none-any.whl → 0.35.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. diffusers/__init__.py +98 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/custom_blocks.py +134 -0
  4. diffusers/commands/diffusers_cli.py +2 -0
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +11 -2
  7. diffusers/dependency_versions_table.py +3 -3
  8. diffusers/guiders/__init__.py +41 -0
  9. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  10. diffusers/guiders/auto_guidance.py +190 -0
  11. diffusers/guiders/classifier_free_guidance.py +141 -0
  12. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  13. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  14. diffusers/guiders/guider_utils.py +309 -0
  15. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  16. diffusers/guiders/skip_layer_guidance.py +262 -0
  17. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  18. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  19. diffusers/hooks/__init__.py +17 -0
  20. diffusers/hooks/_common.py +56 -0
  21. diffusers/hooks/_helpers.py +293 -0
  22. diffusers/hooks/faster_cache.py +7 -6
  23. diffusers/hooks/first_block_cache.py +259 -0
  24. diffusers/hooks/group_offloading.py +292 -286
  25. diffusers/hooks/hooks.py +56 -1
  26. diffusers/hooks/layer_skip.py +263 -0
  27. diffusers/hooks/layerwise_casting.py +2 -7
  28. diffusers/hooks/pyramid_attention_broadcast.py +14 -11
  29. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  30. diffusers/hooks/utils.py +43 -0
  31. diffusers/loaders/__init__.py +6 -0
  32. diffusers/loaders/ip_adapter.py +255 -4
  33. diffusers/loaders/lora_base.py +63 -30
  34. diffusers/loaders/lora_conversion_utils.py +434 -53
  35. diffusers/loaders/lora_pipeline.py +834 -37
  36. diffusers/loaders/peft.py +28 -5
  37. diffusers/loaders/single_file_model.py +44 -11
  38. diffusers/loaders/single_file_utils.py +170 -2
  39. diffusers/loaders/transformer_flux.py +9 -10
  40. diffusers/loaders/transformer_sd3.py +6 -1
  41. diffusers/loaders/unet.py +22 -5
  42. diffusers/loaders/unet_loader_utils.py +5 -2
  43. diffusers/models/__init__.py +8 -0
  44. diffusers/models/attention.py +484 -3
  45. diffusers/models/attention_dispatch.py +1218 -0
  46. diffusers/models/attention_processor.py +105 -663
  47. diffusers/models/auto_model.py +2 -2
  48. diffusers/models/autoencoders/__init__.py +1 -0
  49. diffusers/models/autoencoders/autoencoder_dc.py +14 -1
  50. diffusers/models/autoencoders/autoencoder_kl.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
  52. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  53. diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
  54. diffusers/models/cache_utils.py +31 -9
  55. diffusers/models/controlnets/controlnet_flux.py +5 -5
  56. diffusers/models/controlnets/controlnet_union.py +4 -4
  57. diffusers/models/embeddings.py +26 -34
  58. diffusers/models/model_loading_utils.py +233 -1
  59. diffusers/models/modeling_flax_utils.py +1 -2
  60. diffusers/models/modeling_utils.py +159 -94
  61. diffusers/models/transformers/__init__.py +2 -0
  62. diffusers/models/transformers/transformer_chroma.py +16 -117
  63. diffusers/models/transformers/transformer_cogview4.py +36 -2
  64. diffusers/models/transformers/transformer_cosmos.py +11 -4
  65. diffusers/models/transformers/transformer_flux.py +372 -132
  66. diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
  67. diffusers/models/transformers/transformer_ltx.py +104 -23
  68. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  69. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  70. diffusers/models/transformers/transformer_wan.py +298 -85
  71. diffusers/models/transformers/transformer_wan_vace.py +15 -21
  72. diffusers/models/unets/unet_2d_condition.py +2 -1
  73. diffusers/modular_pipelines/__init__.py +83 -0
  74. diffusers/modular_pipelines/components_manager.py +1068 -0
  75. diffusers/modular_pipelines/flux/__init__.py +66 -0
  76. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  77. diffusers/modular_pipelines/flux/decoders.py +109 -0
  78. diffusers/modular_pipelines/flux/denoise.py +227 -0
  79. diffusers/modular_pipelines/flux/encoders.py +412 -0
  80. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  81. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  82. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  83. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  84. diffusers/modular_pipelines/node_utils.py +665 -0
  85. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  86. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  87. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  88. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  89. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  90. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  91. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  92. diffusers/modular_pipelines/wan/__init__.py +66 -0
  93. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  94. diffusers/modular_pipelines/wan/decoders.py +105 -0
  95. diffusers/modular_pipelines/wan/denoise.py +261 -0
  96. diffusers/modular_pipelines/wan/encoders.py +242 -0
  97. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  98. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  99. diffusers/pipelines/__init__.py +31 -0
  100. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
  101. diffusers/pipelines/auto_pipeline.py +17 -13
  102. diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
  103. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
  104. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
  105. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
  106. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
  107. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
  108. diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
  109. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
  110. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
  111. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
  113. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
  114. diffusers/pipelines/dit/pipeline_dit.py +3 -1
  115. diffusers/pipelines/flux/__init__.py +4 -0
  116. diffusers/pipelines/flux/pipeline_flux.py +34 -26
  117. diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
  118. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
  119. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
  120. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
  121. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
  122. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
  123. diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
  124. diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
  125. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
  126. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  127. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  128. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  129. diffusers/pipelines/flux/pipeline_output.py +6 -4
  130. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
  131. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
  132. diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
  133. diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
  134. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
  135. diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
  136. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  137. diffusers/pipelines/pipeline_loading_utils.py +24 -2
  138. diffusers/pipelines/pipeline_utils.py +22 -15
  139. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
  140. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
  141. diffusers/pipelines/qwenimage/__init__.py +55 -0
  142. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  143. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  144. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +849 -0
  145. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  146. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  147. diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
  148. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  149. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  150. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  151. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  152. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  153. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  154. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  155. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
  156. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  157. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  158. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
  160. diffusers/pipelines/wan/pipeline_wan.py +78 -20
  161. diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
  162. diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
  163. diffusers/quantizers/__init__.py +1 -177
  164. diffusers/quantizers/base.py +11 -0
  165. diffusers/quantizers/gguf/utils.py +92 -3
  166. diffusers/quantizers/pipe_quant_config.py +202 -0
  167. diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
  168. diffusers/schedulers/scheduling_deis_multistep.py +8 -1
  169. diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
  170. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
  171. diffusers/schedulers/scheduling_scm.py +0 -1
  172. diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
  173. diffusers/schedulers/scheduling_utils.py +2 -2
  174. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  175. diffusers/training_utils.py +78 -0
  176. diffusers/utils/__init__.py +10 -0
  177. diffusers/utils/constants.py +4 -0
  178. diffusers/utils/dummy_pt_objects.py +312 -0
  179. diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
  180. diffusers/utils/dynamic_modules_utils.py +84 -25
  181. diffusers/utils/hub_utils.py +33 -17
  182. diffusers/utils/import_utils.py +70 -0
  183. diffusers/utils/peft_utils.py +11 -8
  184. diffusers/utils/testing_utils.py +136 -10
  185. diffusers/utils/torch_utils.py +18 -0
  186. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/METADATA +6 -6
  187. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/RECORD +191 -127
  188. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/LICENSE +0 -0
  189. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/WHEEL +0 -0
  190. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/entry_points.txt +0 -0
  191. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/top_level.txt +0 -0
diffusers/models/transformers/transformer_ltx.py
@@ -1,4 +1,4 @@
- # Copyright 2025 The Genmo team and The HuggingFace Team.
+ # Copyright 2025 The Lightricks team and The HuggingFace Team.
  # All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,19 +13,19 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ import inspect
  import math
  from typing import Any, Dict, Optional, Tuple, Union

  import torch
  import torch.nn as nn
- import torch.nn.functional as F

  from ...configuration_utils import ConfigMixin, register_to_config
  from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
- from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+ from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
  from ...utils.torch_utils import maybe_allow_in_graph
- from ..attention import FeedForward
- from ..attention_processor import Attention
+ from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+ from ..attention_dispatch import dispatch_attention_fn
  from ..cache_utils import CacheMixin
  from ..embeddings import PixArtAlphaTextProjection
  from ..modeling_outputs import Transformer2DModelOutput
@@ -37,20 +37,30 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name


  class LTXVideoAttentionProcessor2_0:
+     def __new__(cls, *args, **kwargs):
+         deprecation_message = "`LTXVideoAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `LTXVideoAttnProcessor`"
+         deprecate("LTXVideoAttentionProcessor2_0", "1.0.0", deprecation_message)
+
+         return LTXVideoAttnProcessor(*args, **kwargs)
+
+
+ class LTXVideoAttnProcessor:
      r"""
-     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
-     used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
+     Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0). This is used in the LTX
+     model. It applies a normalization layer and rotary embedding on the query and key vector.
      """

+     _attention_backend = None
+
      def __init__(self):
-         if not hasattr(F, "scaled_dot_product_attention"):
-             raise ImportError(
-                 "LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+         if is_torch_version("<", "2.0"):
+             raise ValueError(
+                 "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
              )

      def __call__(
          self,
-         attn: Attention,
+         attn: "LTXAttention",
          hidden_states: torch.Tensor,
          encoder_hidden_states: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
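The `__new__` shim above keeps old imports working: constructing `LTXVideoAttentionProcessor2_0` warns once and hands back an instance of the renamed `LTXVideoAttnProcessor`. A minimal sketch of the resulting behavior, assuming only that `diffusers.utils.deprecate` surfaces the message through the standard `warnings` machinery:

```python
import warnings

from diffusers.models.transformers.transformer_ltx import (
    LTXVideoAttentionProcessor2_0,
    LTXVideoAttnProcessor,
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    proc = LTXVideoAttentionProcessor2_0()  # old name, still constructible

# __new__ swapped in the renamed class, so callers transparently get the new type
assert isinstance(proc, LTXVideoAttnProcessor)
assert any("LTXVideoAttentionProcessor2_0" in str(w.message) for w in caught)
```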
@@ -78,14 +88,20 @@ class LTXVideoAttentionProcessor2_0:
          query = apply_rotary_emb(query, image_rotary_emb)
          key = apply_rotary_emb(key, image_rotary_emb)

-         query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-         key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-         value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-
-         hidden_states = F.scaled_dot_product_attention(
-             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+         query = query.unflatten(2, (attn.heads, -1))
+         key = key.unflatten(2, (attn.heads, -1))
+         value = value.unflatten(2, (attn.heads, -1))
+
+         hidden_states = dispatch_attention_fn(
+             query,
+             key,
+             value,
+             attn_mask=attention_mask,
+             dropout_p=0.0,
+             is_causal=False,
+             backend=self._attention_backend,
          )
-         hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+         hidden_states = hidden_states.flatten(2, 3)
          hidden_states = hidden_states.to(query.dtype)

          hidden_states = attn.to_out[0](hidden_states)
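Note the layout change that accompanies the switch to `dispatch_attention_fn`: the old path transposed to `(batch, heads, seq_len, head_dim)` for `F.scaled_dot_product_attention` and transposed back, while the new path hands `(batch, seq_len, heads, head_dim)` tensors to the dispatcher, which handles any backend-specific layout itself. A standalone sketch of that equivalence, assuming `backend=None` resolves to the default native SDPA backend (which is what the processor's `_attention_backend = None` implies):

```python
import torch
import torch.nn.functional as F

from diffusers.models.attention_dispatch import dispatch_attention_fn

batch, seq_len, heads, head_dim = 2, 16, 8, 64
q, k, v = (torch.randn(batch, seq_len, heads, head_dim) for _ in range(3))

# Old processor path: heads moved to dim 1 for SDPA, then moved back.
out_old = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=0.0, is_causal=False
).transpose(1, 2)

# New processor path: tensors stay in (batch, seq_len, heads, head_dim).
out_new = dispatch_attention_fn(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, backend=None
)

torch.testing.assert_close(out_old, out_new)
```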
@@ -93,6 +109,70 @@ class LTXVideoAttentionProcessor2_0:
          return hidden_states


+ class LTXAttention(torch.nn.Module, AttentionModuleMixin):
+     _default_processor_cls = LTXVideoAttnProcessor
+     _available_processors = [LTXVideoAttnProcessor]
+
+     def __init__(
+         self,
+         query_dim: int,
+         heads: int = 8,
+         kv_heads: int = 8,
+         dim_head: int = 64,
+         dropout: float = 0.0,
+         bias: bool = True,
+         cross_attention_dim: Optional[int] = None,
+         out_bias: bool = True,
+         qk_norm: str = "rms_norm_across_heads",
+         processor=None,
+     ):
+         super().__init__()
+         if qk_norm != "rms_norm_across_heads":
+             raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.")
+
+         self.head_dim = dim_head
+         self.inner_dim = dim_head * heads
+         self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+         self.query_dim = query_dim
+         self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+         self.use_bias = bias
+         self.dropout = dropout
+         self.out_dim = query_dim
+         self.heads = heads
+
+         norm_eps = 1e-5
+         norm_elementwise_affine = True
+         self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+         self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+         self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+         self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+         self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+         self.to_out = torch.nn.ModuleList([])
+         self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+         self.to_out.append(torch.nn.Dropout(dropout))
+
+         if processor is None:
+             processor = self._default_processor_cls()
+         self.set_processor(processor)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         image_rotary_emb: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+         unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
+         if len(unused_kwargs) > 0:
+             logger.warning(
+                 f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+             )
+         kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+         return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
  class LTXVideoRotaryPosEmbed(nn.Module):
      def __init__(
          self,
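For orientation, a minimal sketch of the new `LTXAttention` module in isolation. Shapes are arbitrary; with the defaults, `inner_dim = heads * dim_head` and `to_out` projects back to `query_dim`, so the output shape matches the input. Omitting `encoder_hidden_states` exercises the self-attention path (the processor falls back to the hidden states when none are given), and the rotary embedding is only applied when one is passed:

```python
import torch

from diffusers.models.transformers.transformer_ltx import LTXAttention

# Self-attention configuration mirroring the constructor defaults above.
attn = LTXAttention(query_dim=512, heads=8, kv_heads=8, dim_head=64)

hidden_states = torch.randn(2, 128, 512)  # (batch, tokens, query_dim)
out = attn(hidden_states)  # no encoder_hidden_states, no image_rotary_emb

assert out.shape == hidden_states.shape
```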
@@ -231,7 +311,7 @@ class LTXVideoTransformerBlock(nn.Module):
          super().__init__()

          self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
-         self.attn1 = Attention(
+         self.attn1 = LTXAttention(
              query_dim=dim,
              heads=num_attention_heads,
              kv_heads=num_attention_heads,
@@ -240,11 +320,10 @@
              cross_attention_dim=None,
              out_bias=attention_out_bias,
              qk_norm=qk_norm,
-             processor=LTXVideoAttentionProcessor2_0(),
          )

          self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
-         self.attn2 = Attention(
+         self.attn2 = LTXAttention(
              query_dim=dim,
              cross_attention_dim=cross_attention_dim,
              heads=num_attention_heads,
@@ -253,7 +332,6 @@
              bias=attention_bias,
              out_bias=attention_out_bias,
              qk_norm=qk_norm,
-             processor=LTXVideoAttentionProcessor2_0(),
          )

          self.ff = FeedForward(dim, activation_fn=activation_fn)
@@ -299,7 +377,9 @@


  @maybe_allow_in_graph
- class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin):
+ class LTXVideoTransformer3DModel(
+     ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin
+ ):
      r"""
      A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).

@@ -328,6 +408,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin

      _supports_gradient_checkpointing = True
      _skip_layerwise_casting_patterns = ["norm"]
+     _repeated_blocks = ["LTXVideoTransformerBlock"]

      @register_to_config
      def __init__(
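The new `_repeated_blocks` entry opts the LTX transformer into regional compilation, where `torch.compile` is applied once to the listed block class and reused across all of its repeats instead of compiling the whole model. A hedged sketch, assuming the `compile_repeated_blocks` helper that `ModelMixin` exposes for models declaring `_repeated_blocks` (the repo id is taken from the class docstring above; verify both against your installed version):

```python
import torch

from diffusers import LTXVideoTransformer3DModel

# Assumed repo id, from the LTXVideoTransformer3DModel docstring.
transformer = LTXVideoTransformer3DModel.from_pretrained(
    "Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Compiles only the classes named in _repeated_blocks
# (here ["LTXVideoTransformerBlock"]), keeping compile time low while
# covering the layers that dominate the forward pass.
transformer.compile_repeated_blocks(fullgraph=True)
```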