diffusers-0.34.0-py3-none-any.whl → diffusers-0.35.0-py3-none-any.whl
This diff shows changes between publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- diffusers/__init__.py +98 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +2 -0
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_table.py +3 -3
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +7 -6
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +292 -286
- diffusers/hooks/hooks.py +56 -1
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +2 -7
- diffusers/hooks/pyramid_attention_broadcast.py +14 -11
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +255 -4
- diffusers/loaders/lora_base.py +63 -30
- diffusers/loaders/lora_conversion_utils.py +434 -53
- diffusers/loaders/lora_pipeline.py +834 -37
- diffusers/loaders/peft.py +28 -5
- diffusers/loaders/single_file_model.py +44 -11
- diffusers/loaders/single_file_utils.py +170 -2
- diffusers/loaders/transformer_flux.py +9 -10
- diffusers/loaders/transformer_sd3.py +6 -1
- diffusers/loaders/unet.py +22 -5
- diffusers/loaders/unet_loader_utils.py +5 -2
- diffusers/models/__init__.py +8 -0
- diffusers/models/attention.py +484 -3
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_processor.py +105 -663
- diffusers/models/auto_model.py +2 -2
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_dc.py +14 -1
- diffusers/models/autoencoders/autoencoder_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
- diffusers/models/cache_utils.py +31 -9
- diffusers/models/controlnets/controlnet_flux.py +5 -5
- diffusers/models/controlnets/controlnet_union.py +4 -4
- diffusers/models/embeddings.py +26 -34
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +159 -94
- diffusers/models/transformers/__init__.py +2 -0
- diffusers/models/transformers/transformer_chroma.py +16 -117
- diffusers/models/transformers/transformer_cogview4.py +36 -2
- diffusers/models/transformers/transformer_cosmos.py +11 -4
- diffusers/models/transformers/transformer_flux.py +372 -132
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
- diffusers/models/transformers/transformer_ltx.py +104 -23
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_wan.py +298 -85
- diffusers/models/transformers/transformer_wan_vace.py +15 -21
- diffusers/models/unets/unet_2d_condition.py +2 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +31 -0
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
- diffusers/pipelines/auto_pipeline.py +17 -13
- diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
- diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +3 -1
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/pipeline_flux.py +34 -26
- diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
- diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
- diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_loading_utils.py +24 -2
- diffusers/pipelines/pipeline_utils.py +22 -15
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
- diffusers/pipelines/wan/pipeline_wan.py +78 -20
- diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
- diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
- diffusers/quantizers/__init__.py +1 -177
- diffusers/quantizers/base.py +11 -0
- diffusers/quantizers/gguf/utils.py +92 -3
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
- diffusers/schedulers/scheduling_deis_multistep.py +8 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
- diffusers/schedulers/scheduling_scm.py +0 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
- diffusers/schedulers/scheduling_utils.py +2 -2
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/training_utils.py +78 -0
- diffusers/utils/__init__.py +10 -0
- diffusers/utils/constants.py +4 -0
- diffusers/utils/dummy_pt_objects.py +312 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
- diffusers/utils/dynamic_modules_utils.py +84 -25
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +70 -0
- diffusers/utils/peft_utils.py +11 -8
- diffusers/utils/testing_utils.py +136 -10
- diffusers/utils/torch_utils.py +18 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/METADATA +6 -6
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/RECORD +191 -127
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/WHEEL +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
```diff
--- a/diffusers/models/transformers/transformer_ltx.py
+++ b/diffusers/models/transformers/transformer_ltx.py
@@ -1,4 +1,4 @@
-# Copyright 2025 The
+# Copyright 2025 The Lightricks team and The HuggingFace Team.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,19 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import math
 from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
@@ -37,20 +37,30 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
 class LTXVideoAttentionProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`LTXVideoAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `LTXVideoAttnProcessor`"
+        deprecate("LTXVideoAttentionProcessor2_0", "1.0.0", deprecation_message)
+
+        return LTXVideoAttnProcessor(*args, **kwargs)
+
+
+class LTXVideoAttnProcessor:
     r"""
-    Processor for implementing
-
+    Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0). This is used in the LTX
+    model. It applies a normalization layer and rotary embedding on the query and key vector.
     """
 
+    _attention_backend = None
+
     def __init__(self):
-        if
-            raise
-                "
+        if is_torch_version("<", "2.0"):
+            raise ValueError(
+                "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
             )
 
     def __call__(
         self,
-        attn:
+        attn: "LTXAttention",
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
@@ -78,14 +88,20 @@ class LTXVideoAttentionProcessor2_0:
         query = apply_rotary_emb(query, image_rotary_emb)
         key = apply_rotary_emb(key, image_rotary_emb)
 
-        query = query.unflatten(2, (attn.heads, -1))
-        key = key.unflatten(2, (attn.heads, -1))
-        value = value.unflatten(2, (attn.heads, -1))
-
-        hidden_states =
-            query,
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
+
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
-        hidden_states = hidden_states.
+        hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.to(query.dtype)
 
         hidden_states = attn.to_out[0](hidden_states)
@@ -93,6 +109,70 @@ class LTXVideoAttentionProcessor2_0:
         return hidden_states
 
 
+class LTXAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = LTXVideoAttnProcessor
+    _available_processors = [LTXVideoAttnProcessor]
+
+    def __init__(
+        self,
+        query_dim: int,
+        heads: int = 8,
+        kv_heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = True,
+        cross_attention_dim: Optional[int] = None,
+        out_bias: bool = True,
+        qk_norm: str = "rms_norm_across_heads",
+        processor=None,
+    ):
+        super().__init__()
+        if qk_norm != "rms_norm_across_heads":
+            raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.")
+
+        self.head_dim = dim_head
+        self.inner_dim = dim_head * heads
+        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+        self.query_dim = query_dim
+        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+        self.use_bias = bias
+        self.dropout = dropout
+        self.out_dim = query_dim
+        self.heads = heads
+
+        norm_eps = 1e-5
+        norm_elementwise_affine = True
+        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+        self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+        self.to_out = torch.nn.ModuleList([])
+        self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(torch.nn.Dropout(dropout))
+
+        if processor is None:
+            processor = self._default_processor_cls()
+        self.set_processor(processor)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+            )
+        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
 class LTXVideoRotaryPosEmbed(nn.Module):
     def __init__(
         self,
@@ -231,7 +311,7 @@ class LTXVideoTransformerBlock(nn.Module):
         super().__init__()
 
         self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
-        self.attn1 =
+        self.attn1 = LTXAttention(
             query_dim=dim,
             heads=num_attention_heads,
             kv_heads=num_attention_heads,
@@ -240,11 +320,10 @@ class LTXVideoTransformerBlock(nn.Module):
             cross_attention_dim=None,
             out_bias=attention_out_bias,
             qk_norm=qk_norm,
-            processor=LTXVideoAttentionProcessor2_0(),
         )
 
         self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
-        self.attn2 =
+        self.attn2 = LTXAttention(
             query_dim=dim,
             cross_attention_dim=cross_attention_dim,
             heads=num_attention_heads,
@@ -253,7 +332,6 @@ class LTXVideoTransformerBlock(nn.Module):
             bias=attention_bias,
             out_bias=attention_out_bias,
             qk_norm=qk_norm,
-            processor=LTXVideoAttentionProcessor2_0(),
         )
 
         self.ff = FeedForward(dim, activation_fn=activation_fn)
@@ -299,7 +377,9 @@ class LTXVideoTransformerBlock(nn.Module):
 
 
 @maybe_allow_in_graph
-class LTXVideoTransformer3DModel(
+class LTXVideoTransformer3DModel(
+    ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin
+):
     r"""
     A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
 
@@ -328,6 +408,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
 
     _supports_gradient_checkpointing = True
     _skip_layerwise_casting_patterns = ["norm"]
+    _repeated_blocks = ["LTXVideoTransformerBlock"]
 
     @register_to_config
     def __init__(
```