diffusers 0.34.0__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +98 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +2 -0
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_table.py +3 -3
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +7 -6
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +292 -286
- diffusers/hooks/hooks.py +56 -1
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +2 -7
- diffusers/hooks/pyramid_attention_broadcast.py +14 -11
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +255 -4
- diffusers/loaders/lora_base.py +63 -30
- diffusers/loaders/lora_conversion_utils.py +434 -53
- diffusers/loaders/lora_pipeline.py +834 -37
- diffusers/loaders/peft.py +28 -5
- diffusers/loaders/single_file_model.py +44 -11
- diffusers/loaders/single_file_utils.py +170 -2
- diffusers/loaders/transformer_flux.py +9 -10
- diffusers/loaders/transformer_sd3.py +6 -1
- diffusers/loaders/unet.py +22 -5
- diffusers/loaders/unet_loader_utils.py +5 -2
- diffusers/models/__init__.py +8 -0
- diffusers/models/attention.py +484 -3
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_processor.py +105 -663
- diffusers/models/auto_model.py +2 -2
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_dc.py +14 -1
- diffusers/models/autoencoders/autoencoder_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
- diffusers/models/cache_utils.py +31 -9
- diffusers/models/controlnets/controlnet_flux.py +5 -5
- diffusers/models/controlnets/controlnet_union.py +4 -4
- diffusers/models/embeddings.py +26 -34
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +159 -94
- diffusers/models/transformers/__init__.py +2 -0
- diffusers/models/transformers/transformer_chroma.py +16 -117
- diffusers/models/transformers/transformer_cogview4.py +36 -2
- diffusers/models/transformers/transformer_cosmos.py +11 -4
- diffusers/models/transformers/transformer_flux.py +372 -132
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
- diffusers/models/transformers/transformer_ltx.py +104 -23
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_wan.py +298 -85
- diffusers/models/transformers/transformer_wan_vace.py +15 -21
- diffusers/models/unets/unet_2d_condition.py +2 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +31 -0
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
- diffusers/pipelines/auto_pipeline.py +17 -13
- diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
- diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +3 -1
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/pipeline_flux.py +34 -26
- diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
- diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
- diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_loading_utils.py +24 -2
- diffusers/pipelines/pipeline_utils.py +22 -15
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
- diffusers/pipelines/wan/pipeline_wan.py +78 -20
- diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
- diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
- diffusers/quantizers/__init__.py +1 -177
- diffusers/quantizers/base.py +11 -0
- diffusers/quantizers/gguf/utils.py +92 -3
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
- diffusers/schedulers/scheduling_deis_multistep.py +8 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
- diffusers/schedulers/scheduling_scm.py +0 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
- diffusers/schedulers/scheduling_utils.py +2 -2
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/training_utils.py +78 -0
- diffusers/utils/__init__.py +10 -0
- diffusers/utils/constants.py +4 -0
- diffusers/utils/dummy_pt_objects.py +312 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
- diffusers/utils/dynamic_modules_utils.py +84 -25
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +70 -0
- diffusers/utils/peft_utils.py +11 -8
- diffusers/utils/testing_utils.py +136 -10
- diffusers/utils/torch_utils.py +18 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/METADATA +6 -6
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/RECORD +191 -127
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/WHEEL +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
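Among the additions, 0.35.0 introduces the SkyReels-V2 video pipelines and their transformer backbone; the full diff of the new transformer module is reproduced below. As a rough, hedged sketch of how these new pipelines might be loaded (not taken from this diff), assuming the Skywork/SkyReels-V2-T2V-14B-540P-Diffusers checkpoint referenced in the code comments below, the generic DiffusionPipeline loader, and Wan-style video output handling:

import torch
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video

# Illustrative sketch only; repo id taken from the comments in the new
# transformer module below, output attribute assumed to mirror the Wan pipelines.
pipe = DiffusionPipeline.from_pretrained(
    "Skywork/SkyReels-V2-T2V-14B-540P-Diffusers", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

frames = pipe(prompt="A cinematic shot of a sailboat at sunset", num_frames=97).frames[0]
export_to_video(frames, "skyreels_v2_sample.mp4", fps=24)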
diffusers/models/transformers/transformer_skyreels_v2.py
@@ -0,0 +1,607 @@
+# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import FeedForward
+from ..attention_processor import Attention
+from ..cache_utils import CacheMixin
+from ..embeddings import (
+    PixArtAlphaTextProjection,
+    TimestepEmbedding,
+    get_1d_rotary_pos_embed,
+    get_1d_sincos_pos_embed_from_grid,
+)
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin, get_parameter_dtype
+from ..normalization import FP32LayerNorm
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class SkyReelsV2AttnProcessor2_0:
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "SkyReelsV2AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        encoder_hidden_states_img = None
+        if attn.add_k_proj is not None:
+            # 512 is the context length of the text encoder, hardcoded for now
+            image_context_length = encoder_hidden_states.shape[1] - 512
+            encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
+            encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+        if rotary_emb is not None:
+
+            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
+                x_rotated = torch.view_as_complex(hidden_states.to(torch.float32).unflatten(3, (-1, 2)))
+                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
+                return x_out.type_as(hidden_states)
+
+            query = apply_rotary_emb(query, rotary_emb)
+            key = apply_rotary_emb(key, rotary_emb)
+
+        # I2V task
+        hidden_states_img = None
+        if encoder_hidden_states_img is not None:
+            key_img = attn.add_k_proj(encoder_hidden_states_img)
+            key_img = attn.norm_added_k(key_img)
+            value_img = attn.add_v_proj(encoder_hidden_states_img)
+
+            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+            hidden_states_img = F.scaled_dot_product_attention(
+                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+            )
+            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
+            hidden_states_img = hidden_states_img.type_as(query)
+
+        hidden_states = F.scaled_dot_product_attention(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.type_as(query)
+
+        if hidden_states_img is not None:
+            hidden_states = hidden_states + hidden_states_img
+
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanImageEmbedding with WanImageEmbedding -> SkyReelsV2ImageEmbedding
+class SkyReelsV2ImageEmbedding(torch.nn.Module):
+    def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
+        super().__init__()
+
+        self.norm1 = FP32LayerNorm(in_features)
+        self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
+        self.norm2 = FP32LayerNorm(out_features)
+        if pos_embed_seq_len is not None:
+            self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
+        else:
+            self.pos_embed = None
+
+    def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
+        if self.pos_embed is not None:
+            batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
+            encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
+            encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
+
+        hidden_states = self.norm1(encoder_hidden_states_image)
+        hidden_states = self.ff(hidden_states)
+        hidden_states = self.norm2(hidden_states)
+        return hidden_states
+
+
+class SkyReelsV2Timesteps(nn.Module):
+    def __init__(self, num_channels: int, flip_sin_to_cos: bool, output_type: str = "pt"):
+        super().__init__()
+        self.num_channels = num_channels
+        self.output_type = output_type
+        self.flip_sin_to_cos = flip_sin_to_cos
+
+    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
+        original_shape = timesteps.shape
+        t_emb = get_1d_sincos_pos_embed_from_grid(
+            self.num_channels,
+            timesteps,
+            output_type=self.output_type,
+            flip_sin_to_cos=self.flip_sin_to_cos,
+        )
+        # Reshape back to maintain batch structure
+        if len(original_shape) > 1:
+            t_emb = t_emb.reshape(*original_shape, self.num_channels)
+        return t_emb
+
+
+class SkyReelsV2TimeTextImageEmbedding(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        time_freq_dim: int,
+        time_proj_dim: int,
+        text_embed_dim: int,
+        image_embed_dim: Optional[int] = None,
+        pos_embed_seq_len: Optional[int] = None,
+    ):
+        super().__init__()
+
+        self.timesteps_proj = SkyReelsV2Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True)
+        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
+        self.act_fn = nn.SiLU()
+        self.time_proj = nn.Linear(dim, time_proj_dim)
+        self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
+
+        self.image_embedder = None
+        if image_embed_dim is not None:
+            self.image_embedder = SkyReelsV2ImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
+
+    def forward(
+        self,
+        timestep: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_hidden_states_image: Optional[torch.Tensor] = None,
+    ):
+        timestep = self.timesteps_proj(timestep)
+
+        time_embedder_dtype = get_parameter_dtype(self.time_embedder)
+        if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
+            timestep = timestep.to(time_embedder_dtype)
+        temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+        timestep_proj = self.time_proj(self.act_fn(temb))
+
+        encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+        if encoder_hidden_states_image is not None:
+            encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
+
+        return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
+
+
+class SkyReelsV2RotaryPosEmbed(nn.Module):
+    def __init__(
+        self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
+    ):
+        super().__init__()
+
+        self.attention_head_dim = attention_head_dim
+        self.patch_size = patch_size
+        self.max_seq_len = max_seq_len
+
+        h_dim = w_dim = 2 * (attention_head_dim // 6)
+        t_dim = attention_head_dim - h_dim - w_dim
+
+        freqs = []
+        for dim in [t_dim, h_dim, w_dim]:
+            freq = get_1d_rotary_pos_embed(
+                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float32
+            )
+            freqs.append(freq)
+        self.freqs = torch.cat(freqs, dim=1)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+        p_t, p_h, p_w = self.patch_size
+        ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
+
+        freqs = self.freqs.to(hidden_states.device)
+        freqs = freqs.split_with_sizes(
+            [
+                self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
+                self.attention_head_dim // 6,
+                self.attention_head_dim // 6,
+            ],
+            dim=1,
+        )
+
+        freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+        return freqs
+
+
+class SkyReelsV2TransformerBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        ffn_dim: int,
+        num_heads: int,
+        qk_norm: str = "rms_norm_across_heads",
+        cross_attn_norm: bool = False,
+        eps: float = 1e-6,
+        added_kv_proj_dim: Optional[int] = None,
+    ):
+        super().__init__()
+
+        # 1. Self-attention
+        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_heads,
+            kv_heads=num_heads,
+            dim_head=dim // num_heads,
+            qk_norm=qk_norm,
+            eps=eps,
+            bias=True,
+            cross_attention_dim=None,
+            out_bias=True,
+            processor=SkyReelsV2AttnProcessor2_0(),
+        )
+
+        # 2. Cross-attention
+        self.attn2 = Attention(
+            query_dim=dim,
+            heads=num_heads,
+            kv_heads=num_heads,
+            dim_head=dim // num_heads,
+            qk_norm=qk_norm,
+            eps=eps,
+            bias=True,
+            cross_attention_dim=None,
+            out_bias=True,
+            added_kv_proj_dim=added_kv_proj_dim,
+            added_proj_bias=True,
+            processor=SkyReelsV2AttnProcessor2_0(),
+        )
+        self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
+
+        # 3. Feed-forward
+        self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
+        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+
+        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        rotary_emb: torch.Tensor,
+        attention_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        if temb.dim() == 3:
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+                self.scale_shift_table + temb.float()
+            ).chunk(6, dim=1)
+        elif temb.dim() == 4:
+            # For 4D temb in Diffusion Forcing framework, we assume the shape is (b, 6, f * pp_h * pp_w, inner_dim)
+            e = (self.scale_shift_table.unsqueeze(2) + temb.float()).chunk(6, dim=1)
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e]
+        # 1. Self-attention
+        norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
+        attn_output = self.attn1(
+            hidden_states=norm_hidden_states, rotary_emb=rotary_emb, attention_mask=attention_mask
+        )
+        hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
+        # 2. Cross-attention
+        norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+        hidden_states = hidden_states + attn_output
+
+        # 3. Feed-forward
+        norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
+            hidden_states
+        )
+        ff_output = self.ffn(norm_hidden_states)
+        hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
+        return hidden_states
+
+
+class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+    r"""
+    A Transformer model for video-like data used in the Wan-based SkyReels-V2 model.
+
+    Args:
+        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
+            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
+        num_attention_heads (`int`, defaults to `16`):
+            Fixed length for text embeddings.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of channels in each head.
+        in_channels (`int`, defaults to `16`):
+            The number of channels in the input.
+        out_channels (`int`, defaults to `16`):
+            The number of channels in the output.
+        text_dim (`int`, defaults to `4096`):
+            Input dimension for text embeddings.
+        freq_dim (`int`, defaults to `256`):
+            Dimension for sinusoidal time embeddings.
+        ffn_dim (`int`, defaults to `8192`):
+            Intermediate dimension in feed-forward network.
+        num_layers (`int`, defaults to `32`):
+            The number of layers of transformer blocks to use.
+        window_size (`Tuple[int]`, defaults to `(-1, -1)`):
+            Window size for local attention (-1 indicates global attention).
+        cross_attn_norm (`bool`, defaults to `True`):
+            Enable cross-attention normalization.
+        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
+            Enable query/key normalization.
+        eps (`float`, defaults to `1e-6`):
+            Epsilon value for normalization layers.
+        inject_sample_info (`bool`, defaults to `False`):
+            Whether to inject sample information into the model.
+        image_dim (`int`, *optional*):
+            The dimension of the image embeddings.
+        added_kv_proj_dim (`int`, *optional*):
+            The dimension of the added key/value projection.
+        rope_max_seq_len (`int`, defaults to `1024`):
+            The maximum sequence length for the rotary embeddings.
+        pos_embed_seq_len (`int`, *optional*):
+            The sequence length for the positional embeddings.
+    """
+
+    _supports_gradient_checkpointing = True
+    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
+    _no_split_modules = ["SkyReelsV2TransformerBlock"]
+    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
+    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+
+    @register_to_config
+    def __init__(
+        self,
+        patch_size: Tuple[int] = (1, 2, 2),
+        num_attention_heads: int = 16,
+        attention_head_dim: int = 128,
+        in_channels: int = 16,
+        out_channels: int = 16,
+        text_dim: int = 4096,
+        freq_dim: int = 256,
+        ffn_dim: int = 8192,
+        num_layers: int = 32,
+        cross_attn_norm: bool = True,
+        qk_norm: Optional[str] = "rms_norm_across_heads",
+        eps: float = 1e-6,
+        image_dim: Optional[int] = None,
+        added_kv_proj_dim: Optional[int] = None,
+        rope_max_seq_len: int = 1024,
+        pos_embed_seq_len: Optional[int] = None,
+        inject_sample_info: bool = False,
+        num_frame_per_block: int = 1,
+    ) -> None:
+        super().__init__()
+
+        inner_dim = num_attention_heads * attention_head_dim
+        out_channels = out_channels or in_channels
+
+        # 1. Patch & position embedding
+        self.rope = SkyReelsV2RotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
+        self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+
+        # 2. Condition embeddings
+        # image_embedding_dim=1280 for I2V model
+        self.condition_embedder = SkyReelsV2TimeTextImageEmbedding(
+            dim=inner_dim,
+            time_freq_dim=freq_dim,
+            time_proj_dim=inner_dim * 6,
+            text_embed_dim=text_dim,
+            image_embed_dim=image_dim,
+            pos_embed_seq_len=pos_embed_seq_len,
+        )
+
+        # 3. Transformer blocks
+        self.blocks = nn.ModuleList(
+            [
+                SkyReelsV2TransformerBlock(
+                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        # 4. Output norm & projection
+        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+        self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
+        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
+
+        if inject_sample_info:
+            self.fps_embedding = nn.Embedding(2, inner_dim)
+            self.fps_projection = FeedForward(inner_dim, inner_dim * 6, mult=1, activation_fn="linear-silu")
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_hidden_states_image: Optional[torch.Tensor] = None,
+        enable_diffusion_forcing: bool = False,
+        fps: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+        p_t, p_h, p_w = self.config.patch_size
+        post_patch_num_frames = num_frames // p_t
+        post_patch_height = height // p_h
+        post_patch_width = width // p_w
+
+        rotary_emb = self.rope(hidden_states)
+
+        hidden_states = self.patch_embedding(hidden_states)
+        hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+        causal_mask = None
+        if self.config.num_frame_per_block > 1:
+            block_num = post_patch_num_frames // self.config.num_frame_per_block
+            range_tensor = torch.arange(block_num, device=hidden_states.device).repeat_interleave(
+                self.config.num_frame_per_block
+            )
+            causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1)  # f, f
+            causal_mask = causal_mask.view(post_patch_num_frames, 1, 1, post_patch_num_frames, 1, 1)
+            causal_mask = causal_mask.repeat(
+                1, post_patch_height, post_patch_width, 1, post_patch_height, post_patch_width
+            )
+            causal_mask = causal_mask.reshape(
+                post_patch_num_frames * post_patch_height * post_patch_width,
+                post_patch_num_frames * post_patch_height * post_patch_width,
+            )
+            causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
+
+        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+            timestep, encoder_hidden_states, encoder_hidden_states_image
+        )
+
+        timestep_proj = timestep_proj.unflatten(-1, (6, -1))
+
+        if encoder_hidden_states_image is not None:
+            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
+
+        if self.config.inject_sample_info:
+            fps = torch.tensor(fps, dtype=torch.long, device=hidden_states.device)
+
+            fps_emb = self.fps_embedding(fps)
+            if enable_diffusion_forcing:
+                timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1)).repeat(
+                    timestep.shape[1], 1, 1
+                )
+            else:
+                timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1))
+
+        if enable_diffusion_forcing:
+            b, f = timestep.shape
+            temb = temb.view(b, f, 1, 1, -1)
+            timestep_proj = timestep_proj.view(b, f, 1, 1, 6, -1)  # (b, f, 1, 1, 6, inner_dim)
+            temb = temb.repeat(1, 1, post_patch_height, post_patch_width, 1).flatten(1, 3)
+            timestep_proj = timestep_proj.repeat(1, 1, post_patch_height, post_patch_width, 1, 1).flatten(
+                1, 3
+            )  # (b, f, pp_h, pp_w, 6, inner_dim) -> (b, f * pp_h * pp_w, 6, inner_dim)
+            timestep_proj = timestep_proj.transpose(1, 2).contiguous()  # (b, 6, f * pp_h * pp_w, inner_dim)
+
+        # 4. Transformer blocks
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            for block in self.blocks:
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    timestep_proj,
+                    rotary_emb,
+                    causal_mask,
+                )
+        else:
+            for block in self.blocks:
+                hidden_states = block(
+                    hidden_states,
+                    encoder_hidden_states,
+                    timestep_proj,
+                    rotary_emb,
+                    causal_mask,
+                )
+
+        if temb.dim() == 2:
+            # If temb is 2D, we assume it has time 1-D time embedding values for each batch.
+            # For models:
+            # - Skywork/SkyReels-V2-T2V-14B-540P-Diffusers
+            # - Skywork/SkyReels-V2-T2V-14B-720P-Diffusers
+            # - Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers
+            # - Skywork/SkyReels-V2-I2V-14B-540P-Diffusers
+            # - Skywork/SkyReels-V2-I2V-14B-720P-Diffusers
+            shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+        elif temb.dim() == 3:
+            # If temb is 3D, we assume it has 2-D time embedding values for each batch.
+            # Each time embedding tensor includes values for each latent frame; thus Diffusion Forcing.
+            # For models:
+            # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers
+            # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers
+            # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers
+            shift, scale = (self.scale_shift_table.unsqueeze(2) + temb.unsqueeze(1)).chunk(2, dim=1)
+            shift, scale = shift.squeeze(1), scale.squeeze(1)
+
+        # Move the shift and scale tensors to the same device as hidden_states.
+        # When using multi-GPU inference via accelerate these will be on the
+        # first device rather than the last device, which hidden_states ends up
+        # on.
+        shift = shift.to(hidden_states.device)
+        scale = scale.to(hidden_states.device)
+
+        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+
+        hidden_states = self.proj_out(hidden_states)
+
+        hidden_states = hidden_states.reshape(
+            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
+        )
+        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
+
+    def _set_ar_attention(self, causal_block_size: int):
+        self.register_to_config(num_frame_per_block=causal_block_size)
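To make the patchify/unpatchify shape conventions above concrete, here is a minimal smoke-test sketch that instantiates the new transformer with a deliberately tiny, made-up configuration (real checkpoints use the defaults documented in the class docstring) and runs a single forward pass on random tensors; it assumes the class is exported from the top-level diffusers namespace like the other transformer models in this release:

import torch
from diffusers import SkyReelsV2Transformer3DModel

# Tiny illustrative config only; not a real checkpoint configuration.
model = SkyReelsV2Transformer3DModel(
    patch_size=(1, 2, 2),
    num_attention_heads=2,
    attention_head_dim=24,
    in_channels=4,
    out_channels=4,
    text_dim=16,
    freq_dim=32,
    ffn_dim=64,
    num_layers=1,
    rope_max_seq_len=32,
)

latents = torch.randn(1, 4, 2, 8, 8)    # (batch, channels, frames, height, width)
timestep = torch.tensor([500])          # one timestep per batch entry (T2V/I2V path)
prompt_embeds = torch.randn(1, 10, 16)  # (batch, sequence, text_dim)

sample = model(latents, timestep, prompt_embeds).sample
print(sample.shape)  # expected to match the latent shape: torch.Size([1, 4, 2, 8, 8])

For the Diffusion Forcing variants, the pipelines instead pass a per-frame timestep tensor and call _set_ar_attention to enable the block-causal mask shown in forward.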