diffsynth-engine 0.5.1.dev2__py3-none-any.whl → 0.5.1.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/__init__.py +2 -0
- diffsynth_engine/conf/models/wan/dit/wan2.2-s2v-14b.json +13 -0
- diffsynth_engine/configs/__init__.py +4 -0
- diffsynth_engine/configs/pipeline.py +36 -0
- diffsynth_engine/models/basic/attention.py +7 -4
- diffsynth_engine/models/wan/wan_audio_encoder.py +306 -0
- diffsynth_engine/models/wan/wan_dit.py +6 -2
- diffsynth_engine/models/wan/wan_s2v_dit.py +567 -0
- diffsynth_engine/pipelines/__init__.py +2 -0
- diffsynth_engine/pipelines/wan_s2v.py +685 -0
- diffsynth_engine/utils/constants.py +1 -0
- diffsynth_engine/utils/image.py +7 -0
- diffsynth_engine/utils/video.py +26 -0
- {diffsynth_engine-0.5.1.dev2.dist-info → diffsynth_engine-0.5.1.dev4.dist-info}/METADATA +3 -1
- {diffsynth_engine-0.5.1.dev2.dist-info → diffsynth_engine-0.5.1.dev4.dist-info}/RECORD +18 -14
- {diffsynth_engine-0.5.1.dev2.dist-info → diffsynth_engine-0.5.1.dev4.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.5.1.dev2.dist-info → diffsynth_engine-0.5.1.dev4.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.5.1.dev2.dist-info → diffsynth_engine-0.5.1.dev4.dist-info}/top_level.txt +0 -0
diffsynth_engine/__init__.py
CHANGED
|
@@ -3,6 +3,7 @@ from .configs import (
|
|
|
3
3
|
SDXLPipelineConfig,
|
|
4
4
|
FluxPipelineConfig,
|
|
5
5
|
WanPipelineConfig,
|
|
6
|
+
WanSpeech2VideoPipelineConfig,
|
|
6
7
|
QwenImagePipelineConfig,
|
|
7
8
|
HunyuanPipelineConfig,
|
|
8
9
|
SDStateDicts,
|
|
@@ -45,6 +46,7 @@ __all__ = [
|
|
|
45
46
|
"SDXLPipelineConfig",
|
|
46
47
|
"FluxPipelineConfig",
|
|
47
48
|
"WanPipelineConfig",
|
|
49
|
+
"WanSpeech2VideoPipelineConfig",
|
|
48
50
|
"QwenImagePipelineConfig",
|
|
49
51
|
"HunyuanPipelineConfig",
|
|
50
52
|
"SDStateDicts",
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
{
|
|
2
|
+
"patch_size": [1, 2, 2],
|
|
3
|
+
"in_dim": 16,
|
|
4
|
+
"dim": 5120,
|
|
5
|
+
"ffn_dim": 13824,
|
|
6
|
+
"freq_dim": 256,
|
|
7
|
+
"text_dim": 4096,
|
|
8
|
+
"out_dim": 16,
|
|
9
|
+
"num_heads": 40,
|
|
10
|
+
"num_layers": 40,
|
|
11
|
+
"eps": 1e-6,
|
|
12
|
+
"audio_inject_layers": [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39]
|
|
13
|
+
}
|
|
@@ -7,6 +7,7 @@ from .pipeline import (
|
|
|
7
7
|
SDXLPipelineConfig,
|
|
8
8
|
FluxPipelineConfig,
|
|
9
9
|
WanPipelineConfig,
|
|
10
|
+
WanSpeech2VideoPipelineConfig,
|
|
10
11
|
QwenImagePipelineConfig,
|
|
11
12
|
HunyuanPipelineConfig,
|
|
12
13
|
BaseStateDicts,
|
|
@@ -14,6 +15,7 @@ from .pipeline import (
|
|
|
14
15
|
SDXLStateDicts,
|
|
15
16
|
FluxStateDicts,
|
|
16
17
|
WanStateDicts,
|
|
18
|
+
WanS2VStateDicts,
|
|
17
19
|
QwenImageStateDicts,
|
|
18
20
|
)
|
|
19
21
|
from .controlnet import ControlType, ControlNetParams
|
|
@@ -27,6 +29,7 @@ __all__ = [
|
|
|
27
29
|
"SDXLPipelineConfig",
|
|
28
30
|
"FluxPipelineConfig",
|
|
29
31
|
"WanPipelineConfig",
|
|
32
|
+
"WanSpeech2VideoPipelineConfig",
|
|
30
33
|
"QwenImagePipelineConfig",
|
|
31
34
|
"HunyuanPipelineConfig",
|
|
32
35
|
"BaseStateDicts",
|
|
@@ -34,6 +37,7 @@ __all__ = [
|
|
|
34
37
|
"SDXLStateDicts",
|
|
35
38
|
"FluxStateDicts",
|
|
36
39
|
"WanStateDicts",
|
|
40
|
+
"WanS2VStateDicts",
|
|
37
41
|
"QwenImageStateDicts",
|
|
38
42
|
"ControlType",
|
|
39
43
|
"ControlNetParams",
|
|
@@ -184,6 +184,34 @@ class WanPipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, Bas
|
|
|
184
184
|
init_parallel_config(self)
|
|
185
185
|
|
|
186
186
|
|
|
187
|
+
@dataclass
class WanSpeech2VideoPipelineConfig(WanPipelineConfig):
    """Pipeline configuration for Wan speech-to-video.

    Adds the audio-encoder settings on top of every field inherited from
    ``WanPipelineConfig``.
    """

    # Location(s) of the audio encoder weights; None when not provided.
    audio_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    # dtype used for the audio encoder.
    audio_encoder_dtype: torch.dtype = torch.float32

    @classmethod
    def basic_config(
        cls,
        model_path: str | os.PathLike | List[str | os.PathLike],
        audio_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
        device: str = "cuda",
        parallelism: int = 1,
        offload_mode: Optional[str] = None,
    ) -> "WanSpeech2VideoPipelineConfig":
        """Build a config from the handful of commonly used options.

        CFG parallelism and FSDP are both switched on exactly when
        ``parallelism`` is greater than 1; all other fields keep their defaults.
        """
        multi_device = parallelism > 1
        return cls(
            model_path=model_path,
            audio_encoder_path=audio_encoder_path,
            device=device,
            parallelism=parallelism,
            use_cfg_parallel=multi_device,
            use_fsdp=multi_device,
            offload_mode=offload_mode,
        )

    def __post_init__(self):
        # NOTE(review): init_parallel_config is defined outside this block;
        # presumably it derives/validates the parallel settings — confirm.
        init_parallel_config(self)
|
|
213
|
+
|
|
214
|
+
|
|
187
215
|
@dataclass
|
|
188
216
|
class QwenImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, BaseConfig):
|
|
189
217
|
model_path: str | os.PathLike | List[str | os.PathLike]
|
|
@@ -274,6 +302,14 @@ class WanStateDicts:
|
|
|
274
302
|
image_encoder: Optional[Dict[str, torch.Tensor]] = None
|
|
275
303
|
|
|
276
304
|
|
|
305
|
+
@dataclass
|
|
306
|
+
class WanS2VStateDicts:
|
|
307
|
+
model: Dict[str, torch.Tensor] | Dict[str, Dict[str, torch.Tensor]]
|
|
308
|
+
t5: Dict[str, torch.Tensor]
|
|
309
|
+
vae: Dict[str, torch.Tensor]
|
|
310
|
+
audio_encoder: Dict[str, torch.Tensor]
|
|
311
|
+
|
|
312
|
+
|
|
277
313
|
@dataclass
|
|
278
314
|
class QwenImageStateDicts:
|
|
279
315
|
model: Dict[str, torch.Tensor]
|
|
@@ -135,12 +135,13 @@ def attention(
|
|
|
135
135
|
flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
|
|
136
136
|
if attn_impl is None or attn_impl == "auto":
|
|
137
137
|
if FLASH_ATTN_3_AVAILABLE:
|
|
138
|
-
if flash_attn3_compatible:
|
|
138
|
+
if flash_attn3_compatible and attn_mask is None:
|
|
139
139
|
return flash_attn3(q, k, v, softmax_scale=scale)
|
|
140
140
|
else:
|
|
141
|
-
|
|
142
|
-
f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
|
|
143
|
-
|
|
141
|
+
if not flash_attn3_compatible:
|
|
142
|
+
logger.warning(f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation")
|
|
143
|
+
else:
|
|
144
|
+
logger.debug("flash_attn_3 does not support attention mask, will use fallback attention implementation")
|
|
144
145
|
if XFORMERS_AVAILABLE:
|
|
145
146
|
return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
|
|
146
147
|
if SDPA_AVAILABLE:
|
|
@@ -156,6 +157,8 @@ def attention(
|
|
|
156
157
|
raise RuntimeError(
|
|
157
158
|
f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
|
|
158
159
|
)
|
|
160
|
+
if attn_mask is not None:
|
|
161
|
+
raise RuntimeError("flash_attn_3 does not support attention mask")
|
|
159
162
|
return flash_attn3(q, k, v, softmax_scale=scale)
|
|
160
163
|
if attn_impl == "flash_attn_2":
|
|
161
164
|
return flash_attn2(q, k, v, softmax_scale=scale)
|
|
@@ -0,0 +1,306 @@
|
|
|
1
|
+
from typing import Tuple, Dict
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
import torch.nn.functional as F
|
|
9
|
+
from einops import rearrange
|
|
10
|
+
|
|
11
|
+
from diffsynth_engine.models.base import PreTrainedModel
|
|
12
|
+
from diffsynth_engine.models.basic import attention as attention_ops
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# ⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇ Wav2Vec2ForCTC ⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇
|
|
16
|
+
class Wav2Vec2Config:
    """Hard-coded hyper-parameters for the Wav2Vec2 modules in this file.

    Every instance gets its own copies of the list-valued settings.
    """

    def __init__(self):
        # --- convolutional feature encoder (one entry per conv layer) ---
        self.conv_bias = True
        self.conv_dim = [512] * 7
        self.conv_kernel = [10, 3, 3, 3, 3, 2, 2]
        self.conv_stride = [5, 2, 2, 2, 2, 2, 2]
        # --- transformer encoder ---
        self.hidden_size = 1024
        self.intermediate_size = 4096
        self.layer_norm_eps = 1e-05
        self.num_attention_heads = 16
        # --- convolutional positional embedding ---
        self.num_conv_pos_embedding_groups = 16
        self.num_conv_pos_embeddings = 128
        # --- layer counts ---
        self.num_feat_extract_layers = 7
        self.num_hidden_layers = 24
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Wav2Vec2LayerNormConvLayer(nn.Module):
    """One stage of the convolutional feature encoder: Conv1d -> LayerNorm -> GELU."""

    def __init__(self, config: Wav2Vec2Config, layer_id=0, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        # The first stage reads a single input channel; every later stage reads
        # the channel count produced by the stage before it.
        if layer_id > 0:
            self.in_conv_dim = config.conv_dim[layer_id - 1]
        else:
            self.in_conv_dim = 1
        self.out_conv_dim = config.conv_dim[layer_id]

        factory_kwargs = {"device": device, "dtype": dtype}
        self.conv = nn.Conv1d(
            self.in_conv_dim,
            self.out_conv_dim,
            kernel_size=config.conv_kernel[layer_id],
            stride=config.conv_stride[layer_id],
            bias=config.conv_bias,
            **factory_kwargs,
        )
        self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True, **factory_kwargs)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the convolution, normalize over channels, then apply GELU."""
        conv_out = self.conv(hidden_states)
        # LayerNorm acts on the last dimension, so swap the last two dims to put
        # channels last, normalize, and swap back.
        normed = self.layer_norm(conv_out.transpose(-2, -1)).transpose(-2, -1)
        return F.gelu(normed)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class Wav2Vec2SamePadLayer(nn.Module):
    """Drop the trailing element of the last dimension when the kernel size is even."""

    def __init__(self, num_conv_pos_embeddings: int):
        super().__init__()
        # One trailing step is removed for an even kernel size, none for an odd one.
        is_even = num_conv_pos_embeddings % 2 == 0
        self.num_pad_remove = 1 if is_even else 0

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states`` with ``num_pad_remove`` trailing steps cut off dim 2."""
        if self.num_pad_remove <= 0:
            return hidden_states
        return hidden_states[:, :, : -self.num_pad_remove]
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class Wav2Vec2PositionalConvEmbedding(nn.Module):
    """Positional embedding computed by a grouped, weight-normalized Conv1d."""

    def __init__(self, config: Wav2Vec2Config, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        kernel_size = config.num_conv_pos_embeddings
        grouped_conv = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=config.num_conv_pos_embedding_groups,
            device=device,
            dtype=dtype,
        )
        # Wrap the convolution weight in a weight-norm parametrization along dim 2.
        self.conv = nn.utils.parametrizations.weight_norm(grouped_conv, name="weight", dim=2)
        self.padding = Wav2Vec2SamePadLayer(kernel_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Swap dims 1 and 2, run conv -> pad-trim -> GELU, and swap back."""
        channels_first = hidden_states.transpose(1, 2)
        embedded = self.padding(self.conv(channels_first))
        return F.gelu(embedded).transpose(1, 2)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class Wav2Vec2FeedForward(nn.Module):
    """Two-layer MLP: hidden_size -> intermediate_size -> hidden_size with GELU between."""

    def __init__(self, config: Wav2Vec2Config, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size, **factory_kwargs)
        self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size, **factory_kwargs)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Expand, apply GELU, and project back to the input width."""
        expanded = F.gelu(self.intermediate_dense(hidden_states))
        return self.output_dense(expanded)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class Wav2Vec2Attention(nn.Module):
    """Multi-head self-attention (as in 'Attention Is All You Need')."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.num_heads = num_heads
        linear_kwargs = {"bias": bias, "device": device, "dtype": dtype}
        # Four square projections; creation order is kept as k, v, q, out.
        self.k_proj = nn.Linear(embed_dim, embed_dim, **linear_kwargs)
        self.v_proj = nn.Linear(embed_dim, embed_dim, **linear_kwargs)
        self.q_proj = nn.Linear(embed_dim, embed_dim, **linear_kwargs)
        self.out_proj = nn.Linear(embed_dim, embed_dim, **linear_kwargs)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Self-attend over the sequence. Input shape: Batch x Time x Channel."""

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # (b, s, h*d) -> (b, s, h, d)
            return rearrange(projected, "b s (h d) -> b s h d", h=self.num_heads).contiguous()

        q = split_heads(self.q_proj(hidden_states))
        k = split_heads(self.k_proj(hidden_states))
        v = split_heads(self.v_proj(hidden_states))
        per_head = attention_ops.attention(q=q, k=k, v=v)
        # Merge the heads back into a single channel dimension.
        merged = rearrange(per_head, "b s h d -> b s (h d)").contiguous()
        return self.out_proj(merged)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual."""

    def __init__(self, config: Wav2Vec2Config, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        hidden_size = config.hidden_size
        norm_eps = config.layer_norm_eps
        self.attention = Wav2Vec2Attention(
            embed_dim=hidden_size,
            num_heads=config.num_attention_heads,
            device=device,
            dtype=dtype,
        )
        self.layer_norm = nn.LayerNorm(hidden_size, eps=norm_eps, device=device, dtype=dtype)
        self.feed_forward = Wav2Vec2FeedForward(config, device=device, dtype=dtype)
        self.final_layer_norm = nn.LayerNorm(hidden_size, eps=norm_eps, device=device, dtype=dtype)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the attention sub-block, then the feed-forward sub-block."""
        # Attention: normalize first, add the un-normalized input back afterwards.
        attn_out = self.attention(self.layer_norm(hidden_states))
        hidden_states = hidden_states + attn_out
        # Feed-forward: same normalize-then-residual pattern.
        ff_out = self.feed_forward(self.final_layer_norm(hidden_states))
        return hidden_states + ff_out
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class Wav2Vec2EncoderStableLayerNorm(nn.Module):
    """Stack of encoder layers that returns every intermediate hidden state."""

    def __init__(self, config: Wav2Vec2Config, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config, device=device, dtype=dtype)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
        blocks = []
        for _ in range(config.num_hidden_layers):
            blocks.append(Wav2Vec2EncoderLayerStableLayerNorm(config, device=device, dtype=dtype))
        self.layers = nn.ModuleList(blocks)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return a tuple of hidden states with ``len(self.layers) + 1`` entries.

        Entry ``i`` (for ``i < len(self.layers)``) is the input fed to layer ``i``;
        the last entry is the final layer's output after ``self.layer_norm``.
        """
        hidden_states = hidden_states + self.pos_conv_embed(hidden_states)
        collected = []
        for block in self.layers:
            collected.append(hidden_states)
            hidden_states = block(hidden_states)
        collected.append(self.layer_norm(hidden_states))
        return tuple(collected)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
class Wav2Vec2FeatureEncoder(nn.Module):
    """Chain of ``config.num_feat_extract_layers`` convolutional stages."""

    def __init__(self, config: Wav2Vec2Config, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        stages = []
        for layer_id in range(config.num_feat_extract_layers):
            stages.append(Wav2Vec2LayerNormConvLayer(config, layer_id=layer_id, device=device, dtype=dtype))
        self.conv_layers = nn.ModuleList(stages)

    def forward(self, input_values: torch.Tensor) -> torch.Tensor:
        """Insert a new axis at dim 1 and pass the result through every stage in order."""
        features = input_values[:, None]
        for stage in self.conv_layers:
            features = stage(features)
        return features
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
class Wav2Vec2FeatureProjection(nn.Module):
    """LayerNorm followed by a linear map from the last conv width to hidden_size."""

    def __init__(self, config: Wav2Vec2Config, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        conv_width = config.conv_dim[-1]
        self.layer_norm = nn.LayerNorm(conv_width, eps=config.layer_norm_eps, device=device, dtype=dtype)
        self.projection = nn.Linear(conv_width, config.hidden_size, device=device, dtype=dtype)

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(projected, normalized)``.

        The normalized, non-projected features are returned as well because
        they are needed for quantization.
        """
        normalized = self.layer_norm(hidden_states)
        projected = self.projection(normalized)
        return projected, normalized
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
class Wav2Vec2StateDictConverter:
    """Adapt a checkpoint state dict to the key layout of ``Wav2Vec2Model``."""

    def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Keep only ``wav2vec2.``-prefixed entries and strip that prefix.

        Keys containing ``masked_spec_embed`` and keys without the prefix are
        dropped. The input dict is not modified.
        """
        prefix = "wav2vec2."
        prefix_len = len(prefix)
        return {
            key[prefix_len:]: tensor
            for key, tensor in state_dict.items()
            if key.startswith(prefix) and "masked_spec_embed" not in key
        }
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
class Wav2Vec2Model(PreTrainedModel):
    """Wav2Vec2 backbone: feature extractor -> feature projection -> encoder."""

    converter = Wav2Vec2StateDictConverter()
    _supports_parallelization = False

    def __init__(self, config: Wav2Vec2Config, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        module_kwargs = {"device": device, "dtype": dtype}
        self.feature_extractor = Wav2Vec2FeatureEncoder(config, **module_kwargs)
        self.feature_projection = Wav2Vec2FeatureProjection(config, **module_kwargs)
        self.encoder = Wav2Vec2EncoderStableLayerNorm(config, **module_kwargs)

    def forward(self, input_values: torch.Tensor) -> torch.Tensor:
        """Encode ``input_values`` and return whatever ``self.encoder`` returns."""
        conv_features = self.feature_extractor(input_values)
        # Swap dims 1 and 2 before projecting; the un-projected features are discarded.
        projected, _ = self.feature_projection(conv_features.transpose(1, 2))
        return self.encoder(projected)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
# ⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆ Wav2Vec2ForCTC ⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def get_sample_indices(original_fps: int, target_fps: int, total_frames: int, num_samples: int) -> np.ndarray:
    """Pick ``num_samples`` source-frame indices spaced at ``target_fps``.

    Sample ``k`` is taken at time ``k / target_fps`` seconds and mapped to the
    nearest frame of a sequence recorded at ``original_fps``; indices are
    clipped to ``[0, total_frames - 1]``.

    Raises:
        ValueError: if the requested samples span more time than the source
            sequence (``total_frames / original_fps`` seconds) covers.
    """
    required_duration = num_samples / target_fps
    available_duration = total_frames / original_fps
    if required_duration > available_duration:
        raise ValueError("required_duration must be less than video length")

    # Evenly spaced timestamps in [0, required_duration), end point excluded.
    timestamps = np.linspace(0, required_duration, num_samples, endpoint=False)
    nearest_frames = np.round(timestamps * original_fps).astype(int)
    return np.clip(nearest_frames, 0, total_frames - 1)
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def linear_interpolation(features: torch.Tensor, input_fps: int, output_fps: int) -> torch.Tensor:
    """Resample a feature sequence from ``input_fps`` to ``output_fps``.

    Args:
        features: tensor of shape ``[B, T, C]`` (e.g. ``[1, T, 512]``).
        input_fps: frame rate of ``features`` (audio rate, f_a).
        output_fps: frame rate to resample to (video rate, f_m).

    Returns:
        Tensor of shape ``[B, output_len, C]`` where
        ``output_len = int(T / input_fps * output_fps)``.
    """
    channels_first = features.transpose(1, 2)  # [B, C, T]
    duration = channels_first.shape[2] / float(input_fps)  # T / f_a
    output_len = int(duration * output_fps)  # f_m * T / f_a
    resampled = F.interpolate(
        channels_first, size=output_len, align_corners=True, mode="linear"
    )  # [B, C, output_len]
    return resampled.transpose(1, 2)  # [B, output_len, C]
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def extract_audio_feat(audio_input: torch.Tensor, model: Wav2Vec2Model, dtype=torch.float32, device="cuda:0") -> torch.Tensor:
    """Encode audio with ``model`` and resample the features to 30 fps.

    Each row of ``audio_input`` is standardized over dim 1, moved to ``device``
    and run through ``model``. The hidden states returned by the model are
    concatenated along dim 0, interpolated from 50 fps to the 30 fps video
    rate, and cast to ``dtype``.
    """
    video_rate = 30
    # Per-row standardization; the small constant guards against zero variance.
    mean = audio_input.mean(dim=1, keepdim=True)
    variance = audio_input.var(dim=1, keepdim=True)
    input_values = (audio_input - mean) / torch.sqrt(variance + 1e-7)

    hidden_states = model(input_values.to(device))
    feat = torch.cat(hidden_states)
    feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate)
    return feat.to(dtype)  # Encoding for the motion
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def get_audio_embed_bucket_fps(
    audio_embed: torch.Tensor, num_frames_per_batch: int, fps: int = 16
) -> Tuple[torch.Tensor, int]:
    """Regroup 30 fps audio features into per-video-frame buckets at ``fps``.

    Args:
        audio_embed: tensor of shape ``[num_layers, num_audio_frames, audio_dim]``
            sampled at the fixed 30 fps rate used in this file.
        num_frames_per_batch: number of video frames per batch (clip).
        fps: video frame rate to bucket for.

    Returns:
        ``(batch_audio_embed, max_num_batches)`` where ``batch_audio_embed`` has
        shape ``[max_num_batches * num_frames_per_batch, num_layers, audio_dim]``.
        Buckets that fall past the end of the audio are filled with zeros.
    """
    video_rate = 30
    scale = video_rate / fps
    num_layers, num_audio_frames, audio_dim = audio_embed.shape
    # Always reserve one extra batch so the tail of the audio is covered.
    max_num_batches = int(num_audio_frames / (num_frames_per_batch * scale)) + 1
    num_buckets = max_num_batches * num_frames_per_batch
    # Number of virtual (zero) audio frames needed to fill the last batch.
    num_audio_padding = math.ceil(num_buckets / fps * video_rate) - num_audio_frames
    batch_indices = get_sample_indices(
        original_fps=video_rate,
        target_fps=fps,
        total_frames=num_audio_frames + num_audio_padding,
        num_samples=num_buckets,
    )

    # The padding frame matches the input's dtype and device, so the stacked
    # result keeps the dtype of ``audio_embed`` (e.g. bf16 stays bf16).
    padding_frame = torch.zeros(
        [num_layers, audio_dim], dtype=audio_embed.dtype, device=audio_embed.device
    )
    # Each bucket takes exactly one audio frame. The previous
    # range(idx, idx + stride, stride) window always produced that single index
    # and raised ValueError when the stride truncated to 0 (fps > 30).
    batch_audio_embed = [
        audio_embed[:, int(frame_idx)] if frame_idx < num_audio_frames else padding_frame
        for frame_idx in batch_indices
    ]
    return torch.stack(batch_audio_embed, dim=0), max_num_batches
|
|
@@ -6,6 +6,7 @@ from typing import Any, Dict, Tuple, Optional
|
|
|
6
6
|
from einops import rearrange
|
|
7
7
|
|
|
8
8
|
from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
|
|
9
|
+
from diffsynth_engine.models.basic.attention import attention
|
|
9
10
|
from diffsynth_engine.models.basic import attention as attention_ops
|
|
10
11
|
from diffsynth_engine.models.basic.transformer_helper import RMSNorm
|
|
11
12
|
from diffsynth_engine.utils.constants import (
|
|
@@ -18,6 +19,7 @@ from diffsynth_engine.utils.constants import (
|
|
|
18
19
|
WAN2_2_DIT_T2V_A14B_CONFIG_FILE,
|
|
19
20
|
)
|
|
20
21
|
from diffsynth_engine.utils.gguf import gguf_inference
|
|
22
|
+
from diffsynth_engine.utils.fp8_linear import fp8_inference
|
|
21
23
|
from diffsynth_engine.utils.parallel import (
|
|
22
24
|
cfg_parallel,
|
|
23
25
|
cfg_parallel_unshard,
|
|
@@ -142,12 +144,12 @@ class CrossAttention(nn.Module):
|
|
|
142
144
|
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
|
143
145
|
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
|
|
144
146
|
|
|
145
|
-
x =
|
|
147
|
+
x = attention(q, k, v, **self.attn_kwargs).flatten(2)
|
|
146
148
|
if self.has_image_input:
|
|
147
149
|
k_img, v_img = self.norm_k_img(self.k_img(img)), self.v_img(img)
|
|
148
150
|
k_img = rearrange(k_img, "b s (n d) -> b s n d", n=num_heads)
|
|
149
151
|
v_img = rearrange(v_img, "b s (n d) -> b s n d", n=num_heads)
|
|
150
|
-
y =
|
|
152
|
+
y = attention(q, k_img, v_img, **self.attn_kwargs).flatten(2)
|
|
151
153
|
x = x + y
|
|
152
154
|
return self.o(x)
|
|
153
155
|
|
|
@@ -343,8 +345,10 @@ class WanDiT(PreTrainedModel):
|
|
|
343
345
|
clip_feature: Optional[torch.Tensor] = None, # clip_vision_encoder(img)
|
|
344
346
|
y: Optional[torch.Tensor] = None, # vae_encoder(img)
|
|
345
347
|
):
|
|
348
|
+
fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
|
|
346
349
|
use_cfg = x.shape[0] > 1
|
|
347
350
|
with (
|
|
351
|
+
fp8_inference(fp8_linear_enabled),
|
|
348
352
|
gguf_inference(),
|
|
349
353
|
cfg_parallel((x, context, timestep, clip_feature, y), use_cfg=use_cfg),
|
|
350
354
|
):
|