diffusers 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +66 -5
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +25 -17
- diffusers/loaders/__init__.py +22 -3
- diffusers/loaders/ip_adapter.py +538 -15
- diffusers/loaders/lora_base.py +124 -118
- diffusers/loaders/lora_conversion_utils.py +318 -3
- diffusers/loaders/lora_pipeline.py +1688 -368
- diffusers/loaders/peft.py +379 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +519 -9
- diffusers/loaders/textual_inversion.py +3 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +17 -4
- diffusers/models/__init__.py +47 -14
- diffusers/models/activations.py +22 -9
- diffusers/models/attention.py +13 -4
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2059 -281
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +2 -1
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +29 -495
- diffusers/models/controlnet_sd3.py +25 -379
- diffusers/models/controlnet_sparsectrl.py +46 -718
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +838 -43
- diffusers/models/model_loading_utils.py +88 -6
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +72 -26
- diffusers/models/normalization.py +78 -13
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +1 -1
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +1 -1
- diffusers/models/transformers/transformer_flux.py +30 -9
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +105 -17
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +1 -1
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +5 -5
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +8 -0
- diffusers/pipelines/__init__.py +34 -0
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
- diffusers/pipelines/auto_pipeline.py +53 -6
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
- diffusers/pipelines/flux/__init__.py +13 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +204 -29
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +7 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +25 -4
- diffusers/pipelines/pipeline_utils.py +35 -6
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/auto.py +14 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +280 -2
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
- diffusers/schedulers/scheduling_ddpm.py +2 -6
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
- diffusers/schedulers/scheduling_deis_multistep.py +28 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
- diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
- diffusers/schedulers/scheduling_euler_discrete.py +4 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_heun_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +28 -9
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
- diffusers/training_utils.py +16 -2
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +180 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +31 -39
- diffusers/utils/import_utils.py +67 -0
- diffusers/utils/peft_utils.py +3 -0
- diffusers/utils/testing_utils.py +56 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/METADATA +6 -6
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/RECORD +214 -162
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/WHEEL +1 -1
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/LICENSE +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,469 @@
|
|
1
|
+
# Copyright 2024 The Genmo team and The HuggingFace Team.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
# you may not use this file except in compliance with the License.
|
6
|
+
# You may obtain a copy of the License at
|
7
|
+
#
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
#
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
# See the License for the specific language governing permissions and
|
14
|
+
# limitations under the License.
|
15
|
+
|
16
|
+
import math
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..embeddings import PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle, RMSNorm
|
33
|
+
|
34
|
+
|
35
|
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
36
|
+
|
37
|
+
|
38
|
+
class LTXVideoAttentionProcessor2_0:
    r"""
    Scaled dot-product attention processor used by the LTX video transformer (requires PyTorch >= 2.0).

    The processor projects the inputs to queries, keys and values with the attention module's `to_q` / `to_k` /
    `to_v` layers. It normalizes queries and keys with `norm_q` / `norm_k` and, when given, applies rotary position
    embeddings to them. It then runs `torch.nn.functional.scaled_dot_product_attention` and passes the result
    through the module's two `to_out` layers.
    """

    def __init__(self):
        if hasattr(F, "scaled_dot_product_attention"):
            return
        raise ImportError(
            "LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run attention for `attn`.

        Args:
            attn: The attention module whose projection, normalization and output layers are used.
            hidden_states: Query-side sequence, unpacked below as `(batch, length, channels)`.
            encoder_hidden_states: Optional key/value-side sequence; `hidden_states` is used when omitted
                (self-attention).
            attention_mask: Optional mask, prepared by `attn.prepare_attention_mask` and reshaped per head.
            image_rotary_emb: Optional `(cos, sin)` pair applied to queries and keys via `apply_rotary_emb`.

        Returns:
            The attended sequence after the output projection.
        """
        # Keys/values come from the context; with no context this is plain self-attention.
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, context_length, _ = context.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, context_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        q = attn.to_q(hidden_states)
        k = attn.to_k(context)
        v = attn.to_v(context)

        q = attn.norm_q(q)
        k = attn.norm_k(k)

        # Rotary embeddings are applied on the flat channel axis, before splitting into heads.
        if image_rotary_emb is not None:
            q = apply_rotary_emb(q, image_rotary_emb)
            k = apply_rotary_emb(k, image_rotary_emb)

        # (batch, length, heads * head_dim) -> (batch, heads, length, head_dim)
        q, k, v = (t.unflatten(2, (attn.heads, -1)).transpose(1, 2) for t in (q, k, v))

        attended = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
        # (batch, heads, length, head_dim) -> (batch, length, heads * head_dim)
        attended = attended.transpose(1, 2).flatten(2, 3).to(q.dtype)

        attended = attn.to_out[0](attended)
        return attn.to_out[1](attended)
|
93
|
+
|
94
|
+
|
95
|
+
class LTXVideoRotaryPosEmbed(nn.Module):
    r"""
    Rotary position embedding over a (frame, height, width) token grid, as used by the LTX video transformer.

    Args:
        dim (`int`):
            Channel width the returned cos/sin tables must span. `dim // 6` frequencies are used per axis; any
            remainder channels are filled with an identity rotation (cos = 1, sin = 0).
        base_num_frames (`int`, defaults to `20`), base_height (`int`, defaults to `2048`),
        base_width (`int`, defaults to `2048`):
            Reference extents that divide the positions when `rope_interpolation_scale` is supplied.
        patch_size (`int`, defaults to `1`), patch_size_t (`int`, defaults to `1`):
            Spatial / temporal patch sizes that multiply the positions when `rope_interpolation_scale` is supplied.
        theta (`float`, defaults to `10000.0`):
            Base of the geometric progression of frequencies.
    """

    def __init__(
        self,
        dim: int,
        base_num_frames: int = 20,
        base_height: int = 2048,
        base_width: int = 2048,
        patch_size: int = 1,
        patch_size_t: int = 1,
        theta: float = 10000.0,
    ) -> None:
        super().__init__()

        self.dim = dim
        self.base_num_frames = base_num_frames
        self.base_height = base_height
        self.base_width = base_width
        self.patch_size = patch_size
        self.patch_size_t = patch_size_t
        self.theta = theta

    def forward(
        self,
        hidden_states: torch.Tensor,
        num_frames: int,
        height: int,
        width: int,
        rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Build `(cos, sin)` tables of shape `(batch, num_frames * height * width, dim)`.

        `hidden_states` is only read for its batch size and device. When `rope_interpolation_scale` is given,
        the position along axis i is multiplied by `rope_interpolation_scale[i] * patch / base` for that axis.
        """
        device = hidden_states.device
        batch_size = hidden_states.size(0)

        # Positions are always built in float32, whatever the input dtype.
        axes = [torch.arange(n, dtype=torch.float32, device=device) for n in (num_frames, height, width)]
        coords = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=0)  # (3, F, H, W)
        coords = coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)  # (B, 3, F, H, W)

        if rope_interpolation_scale is not None:
            per_axis = (
                (self.patch_size_t, self.base_num_frames),
                (self.patch_size, self.base_height),
                (self.patch_size, self.base_width),
            )
            for axis, (patch, base) in enumerate(per_axis):
                coords[:, axis : axis + 1] = coords[:, axis : axis + 1] * rope_interpolation_scale[axis] * patch / base

        coords = coords.flatten(2, 4).transpose(1, 2)  # (B, F*H*W, 3)

        exponents = torch.linspace(
            math.log(1.0, self.theta),
            math.log(self.theta, self.theta),
            self.dim // 6,
            device=device,
            dtype=torch.float32,
        )
        freqs = self.theta**exponents
        freqs = freqs * math.pi / 2.0
        # Each position p is mapped to 2p - 1 before scaling the frequencies: (B, S, 3, dim // 6).
        freqs = freqs * (coords.unsqueeze(-1) * 2 - 1)
        # Reorder to (B, S, dim // 6, 3) and flatten the last two axes.
        freqs = freqs.transpose(-1, -2).flatten(2)

        # Each angle is used for a pair of adjacent channels.
        cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
        sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)

        remainder = self.dim % 6
        if remainder:
            # Leading channels that do not fit 6-channel groups get an identity rotation.
            cos_pad = torch.ones_like(cos_freqs[..., :remainder])
            sin_pad = torch.zeros_like(cos_freqs[..., :remainder])
            cos_freqs = torch.cat([cos_pad, cos_freqs], dim=-1)
            sin_freqs = torch.cat([sin_pad, sin_freqs], dim=-1)

        return cos_freqs, sin_freqs
|
164
|
+
|
165
|
+
|
166
|
+
@maybe_allow_in_graph
class LTXVideoTransformerBlock(nn.Module):
    r"""
    Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video).

    The block applies three residual updates in order:
    1. Self-attention with rotary embeddings, with modulated input and gated output.
    2. Cross-attention to `encoder_hidden_states`, which is neither modulated nor gated.
    3. A feed-forward network, with modulated input and gated output.

    Args:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        cross_attention_dim (`int`):
            The number of channels of `encoder_hidden_states` consumed by the cross-attention layer.
        qk_norm (`str`, defaults to `"rms_norm_across_heads"`):
            Query/key normalization type passed to both `Attention` layers.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        attention_bias (`bool`, defaults to `True`):
            Passed to both `Attention` layers as `bias`.
        attention_out_bias (`bool`, defaults to `True`):
            Passed to both `Attention` layers as `out_bias`.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        elementwise_affine (`bool`, defaults to `False`):
            Whether the two `RMSNorm` layers have learnable affine parameters.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        cross_attention_dim: int,
        qk_norm: str = "rms_norm_across_heads",
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = True,
        attention_out_bias: bool = True,
        eps: float = 1e-6,
        elementwise_affine: bool = False,
    ):
        super().__init__()

        self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            kv_heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=attention_bias,
            cross_attention_dim=None,
            out_bias=attention_out_bias,
            qk_norm=qk_norm,
            processor=LTXVideoAttentionProcessor2_0(),
        )

        self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
        self.attn2 = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            heads=num_attention_heads,
            kv_heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=attention_bias,
            out_bias=attention_out_bias,
            qk_norm=qk_norm,
            processor=LTXVideoAttentionProcessor2_0(),
        )

        self.ff = FeedForward(dim, activation_fn=activation_fn)

        # Six learned offsets, one per modulation value: shift/scale/gate for attention and for the MLP.
        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Apply the block to `hidden_states`.

        Args:
            hidden_states: Token sequence; dim 0 is the batch.
            encoder_hidden_states: Context sequence for cross-attention.
            temb: Timestep embedding. It is reshaped to `(batch, temb.size(1), 6, -1)` and added to
                `scale_shift_table`, so its last dim presumably equals `6 * dim` — TODO confirm against caller.
            image_rotary_emb: Optional `(cos, sin)` rotary tables, used by self-attention only.
            encoder_attention_mask: Optional mask for the cross-attention layer.

        Returns:
            The updated `hidden_states`.
        """
        batch_size = hidden_states.size(0)
        norm_hidden_states = self.norm1(hidden_states)

        num_ada_params = self.scale_shift_table.shape[0]
        ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
        norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa

        attn_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=None,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = hidden_states + attn_hidden_states * gate_msa

        # Cross-attention takes the un-normalized residual stream and its output is added without a gate.
        attn_hidden_states = self.attn2(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            image_rotary_emb=None,
            attention_mask=encoder_attention_mask,
        )
        hidden_states = hidden_states + attn_hidden_states
        norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp

        ff_output = self.ff(norm_hidden_states)
        hidden_states = hidden_states + ff_output * gate_mlp

        return hidden_states
|
267
|
+
|
268
|
+
|
269
|
+
@maybe_allow_in_graph
class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
    r"""
    A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).

    Args:
        in_channels (`int`, defaults to `128`):
            The number of channels in the input.
        out_channels (`int`, defaults to `128`):
            The number of channels in the output. A falsy value falls back to `in_channels`.
        patch_size (`int`, defaults to `1`):
            The spatial patch size. In this model it is only used to normalize rotary positions when
            `rope_interpolation_scale` is passed to `forward`.
        patch_size_t (`int`, defaults to `1`):
            The temporal patch size. It plays the same role as `patch_size`, but for the frame axis.
        num_attention_heads (`int`, defaults to `32`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `64`):
            The number of channels in each head.
        cross_attention_dim (`int`, defaults to `2048`):
            The number of channels for cross attention heads.
        num_layers (`int`, defaults to `28`):
            The number of layers of Transformer blocks to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        qk_norm (`str`, defaults to `"rms_norm_across_heads"`):
            The normalization layer to use.
        norm_elementwise_affine (`bool`, defaults to `False`):
            Whether the `RMSNorm` layers inside each transformer block have learnable affine parameters.
        norm_eps (`float`, defaults to `1e-6`):
            Epsilon of the normalization layers inside each transformer block.
        caption_channels (`int`, defaults to `4096`):
            Width of the incoming `encoder_hidden_states`. `caption_projection` maps it to
            `num_attention_heads * attention_head_dim`.
        attention_bias (`bool`, defaults to `True`):
            Passed to every block's `Attention` layers as `bias`.
        attention_out_bias (`bool`, defaults to `True`):
            Passed to every block's `Attention` layers as `out_bias`.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 128,
        out_channels: int = 128,
        patch_size: int = 1,
        patch_size_t: int = 1,
        num_attention_heads: int = 32,
        attention_head_dim: int = 64,
        cross_attention_dim: int = 2048,
        num_layers: int = 28,
        activation_fn: str = "gelu-approximate",
        qk_norm: str = "rms_norm_across_heads",
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        caption_channels: int = 4096,
        attention_bias: bool = True,
        attention_out_bias: bool = True,
    ) -> None:
        super().__init__()

        out_channels = out_channels or in_channels
        inner_dim = num_attention_heads * attention_head_dim

        self.proj_in = nn.Linear(in_channels, inner_dim)

        # Learned (shift, scale) offsets for the output modulation.
        self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
        self.time_embed = AdaLayerNormSingle(inner_dim, use_additional_conditions=False)

        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)

        self.rope = LTXVideoRotaryPosEmbed(
            dim=inner_dim,
            base_num_frames=20,
            base_height=2048,
            base_width=2048,
            patch_size=patch_size,
            patch_size_t=patch_size_t,
            theta=10000.0,
        )

        self.transformer_blocks = nn.ModuleList(
            [
                LTXVideoTransformerBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    cross_attention_dim=cross_attention_dim,
                    qk_norm=qk_norm,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    attention_out_bias=attention_out_bias,
                    eps=norm_eps,
                    elementwise_affine=norm_elementwise_affine,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_attention_mask: Optional[torch.Tensor],
        num_frames: int,
        height: int,
        width: int,
        rope_interpolation_scale: Optional[Tuple[Union[float, torch.Tensor], float, float]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
        """
        Denoise a token sequence conditioned on text and timestep.

        Args:
            hidden_states: Input tokens. The last dim is `in_channels` (consumed by `proj_in`) and dim 0 is the
                batch. The sequence length presumably equals `num_frames * height * width`, which is the number
                of RoPE positions built here — TODO confirm against caller.
            encoder_hidden_states: Text-encoder states with last dim `caption_channels`.
            timestep: Diffusion timestep(s). They are flattened before embedding, and the embeddings are viewed
                as `(batch, -1, inner_dim)`-shaped chunks.
            encoder_attention_mask: Optional text mask. A 2D mask (1 = attend, 0 = ignore) is converted into an
                additive bias of 0 / -10000; masks of any other rank are passed through unchanged.
            num_frames, height, width: Extents of the latent token grid used to build rotary embeddings.
            rope_interpolation_scale: Optional per-axis (frames, height, width) multipliers for RoPE positions.
            attention_kwargs: Only the `"scale"` entry is read. It is the LoRA scale and only has an effect with
                the PEFT backend.
            return_dict: Whether to return a `Transformer2DModelOutput` instead of a one-element tuple.

        Returns:
            `Transformer2DModelOutput(sample=...)` if `return_dict` is True, otherwise `(sample,)`.
        """
        # Decide whether the caller supplied a LoRA scale *before* popping it. Checking the popped copy would
        # always see `None`, so the warning below could never fire.
        scale_requested = attention_kwargs is not None and attention_kwargs.get("scale", None) is not None
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        elif scale_requested:
            logger.warning("Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective.")

        try:
            image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)

            # convert encoder_attention_mask to a bias the same way we do for attention_mask
            if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
                encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
                encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

            batch_size = hidden_states.size(0)
            hidden_states = self.proj_in(hidden_states)

            temb, embedded_timestep = self.time_embed(
                timestep.flatten(),
                batch_size=batch_size,
                hidden_dtype=hidden_states.dtype,
            )

            temb = temb.view(batch_size, -1, temb.size(-1))
            embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))

            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))

            # Loop-invariant checkpointing setup, built once instead of per block.
            use_checkpointing = torch.is_grad_enabled() and self.gradient_checkpointing
            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    return module(*inputs)

                return custom_forward

            for block in self.transformer_blocks:
                if use_checkpointing:
                    # Positional order matches LTXVideoTransformerBlock.forward.
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        hidden_states,
                        encoder_hidden_states,
                        temb,
                        image_rotary_emb,
                        encoder_attention_mask,
                        **ckpt_kwargs,
                    )
                else:
                    hidden_states = block(
                        hidden_states=hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        temb=temb,
                        image_rotary_emb=image_rotary_emb,
                        encoder_attention_mask=encoder_attention_mask,
                    )

            scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
            shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]

            hidden_states = self.norm_out(hidden_states)
            hidden_states = hidden_states * (1 + scale) + shift
            output = self.proj_out(hidden_states)
        finally:
            # Always undo the LoRA scaling, even if the forward pass raised; otherwise the next call would
            # compound the scale on top of the leftover one.
            if USE_PEFT_BACKEND:
                unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
|
462
|
+
|
463
|
+
|
464
|
+
def apply_rotary_emb(x, freqs):
    """
    Rotate interleaved channel pairs of `x` by the angles encoded in `freqs`.

    Channels are grouped as consecutive pairs `(x[..., 2i], x[..., 2i + 1])`. Each pair `(a, b)` becomes
    `(a * cos - b * sin, b * cos + a * sin)`.

    Args:
        x (`torch.Tensor`):
            Tensor whose last dimension has even size. In this file it is applied to query/key projections of
            shape `(batch, seq_len, channels)` before the head split; any number of leading dims is accepted.
        freqs (`Tuple[torch.Tensor, torch.Tensor]`):
            `(cos, sin)` tables broadcastable against `x`, with each angle repeated for both members of a pair.

    Returns:
        `torch.Tensor`: Same shape and dtype as `x`. `x` is upcast with `.float()` for the arithmetic.
    """
    cos, sin = freqs
    # Split the last dim into (pairs, 2); x_real / x_imag each have shape x.shape[:-1] + (C // 2,).
    x_real, x_imag = x.unflatten(-1, (-1, 2)).unbind(-1)
    # (a, b) -> (-b, a), re-interleaved back to the original channel layout.
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(-2)
    out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
    return out
|