diffsynth-engine 0.5.1.dev2__py3-none-any.whl → 0.5.1.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,6 +3,7 @@ from .configs import (
3
3
  SDXLPipelineConfig,
4
4
  FluxPipelineConfig,
5
5
  WanPipelineConfig,
6
+ WanSpeech2VideoPipelineConfig,
6
7
  QwenImagePipelineConfig,
7
8
  HunyuanPipelineConfig,
8
9
  SDStateDicts,
@@ -45,6 +46,7 @@ __all__ = [
45
46
  "SDXLPipelineConfig",
46
47
  "FluxPipelineConfig",
47
48
  "WanPipelineConfig",
49
+ "WanSpeech2VideoPipelineConfig",
48
50
  "QwenImagePipelineConfig",
49
51
  "HunyuanPipelineConfig",
50
52
  "SDStateDicts",
@@ -0,0 +1,13 @@
1
+ {
2
+ "patch_size": [1, 2, 2],
3
+ "in_dim": 16,
4
+ "dim": 5120,
5
+ "ffn_dim": 13824,
6
+ "freq_dim": 256,
7
+ "text_dim": 4096,
8
+ "out_dim": 16,
9
+ "num_heads": 40,
10
+ "num_layers": 40,
11
+ "eps": 1e-6,
12
+ "audio_inject_layers": [0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39]
13
+ }
@@ -7,6 +7,7 @@ from .pipeline import (
7
7
  SDXLPipelineConfig,
8
8
  FluxPipelineConfig,
9
9
  WanPipelineConfig,
10
+ WanSpeech2VideoPipelineConfig,
10
11
  QwenImagePipelineConfig,
11
12
  HunyuanPipelineConfig,
12
13
  BaseStateDicts,
@@ -14,6 +15,7 @@ from .pipeline import (
14
15
  SDXLStateDicts,
15
16
  FluxStateDicts,
16
17
  WanStateDicts,
18
+ WanS2VStateDicts,
17
19
  QwenImageStateDicts,
18
20
  )
19
21
  from .controlnet import ControlType, ControlNetParams
@@ -27,6 +29,7 @@ __all__ = [
27
29
  "SDXLPipelineConfig",
28
30
  "FluxPipelineConfig",
29
31
  "WanPipelineConfig",
32
+ "WanSpeech2VideoPipelineConfig",
30
33
  "QwenImagePipelineConfig",
31
34
  "HunyuanPipelineConfig",
32
35
  "BaseStateDicts",
@@ -34,6 +37,7 @@ __all__ = [
34
37
  "SDXLStateDicts",
35
38
  "FluxStateDicts",
36
39
  "WanStateDicts",
40
+ "WanS2VStateDicts",
37
41
  "QwenImageStateDicts",
38
42
  "ControlType",
39
43
  "ControlNetParams",
@@ -184,6 +184,34 @@ class WanPipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, Bas
184
184
  init_parallel_config(self)
185
185
 
186
186
 
187
@dataclass
class WanSpeech2VideoPipelineConfig(WanPipelineConfig):
    """Wan speech-to-video pipeline settings.

    Inherits every field of ``WanPipelineConfig`` and adds the two
    audio-encoder options declared below.
    """

    # Optional audio encoder location: a single path or a list of paths.
    audio_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
    # dtype requested for the audio encoder (float32 unless overridden).
    audio_encoder_dtype: torch.dtype = torch.float32

    @classmethod
    def basic_config(
        cls,
        model_path: str | os.PathLike | List[str | os.PathLike],
        audio_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
        device: str = "cuda",
        parallelism: int = 1,
        offload_mode: Optional[str] = None,
    ) -> "WanSpeech2VideoPipelineConfig":
        """Build a config from the most commonly used options.

        ``use_cfg_parallel`` and ``use_fsdp`` are both enabled exactly when
        ``parallelism`` is greater than 1; every field not listed here keeps
        its dataclass default.
        """
        multi_device = bool(parallelism > 1)
        settings = {
            "model_path": model_path,
            "audio_encoder_path": audio_encoder_path,
            "device": device,
            "parallelism": parallelism,
            "use_cfg_parallel": multi_device,
            "use_fsdp": multi_device,
            "offload_mode": offload_mode,
        }
        return cls(**settings)

    def __post_init__(self):
        # Hand the populated config to the shared parallel-config initializer.
        init_parallel_config(self)
213
+
214
+
187
215
  @dataclass
188
216
  class QwenImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, BaseConfig):
189
217
  model_path: str | os.PathLike | List[str | os.PathLike]
@@ -274,6 +302,14 @@ class WanStateDicts:
274
302
  image_encoder: Optional[Dict[str, torch.Tensor]] = None
275
303
 
276
304
 
305
+ @dataclass
306
+ class WanS2VStateDicts:
307
+ model: Dict[str, torch.Tensor] | Dict[str, Dict[str, torch.Tensor]]
308
+ t5: Dict[str, torch.Tensor]
309
+ vae: Dict[str, torch.Tensor]
310
+ audio_encoder: Dict[str, torch.Tensor]
311
+
312
+
277
313
  @dataclass
278
314
  class QwenImageStateDicts:
279
315
  model: Dict[str, torch.Tensor]
@@ -135,12 +135,13 @@ def attention(
135
135
  flash_attn3_compatible = q.shape[-1] <= FA3_MAX_HEADDIM
136
136
  if attn_impl is None or attn_impl == "auto":
137
137
  if FLASH_ATTN_3_AVAILABLE:
138
- if flash_attn3_compatible:
138
+ if flash_attn3_compatible and attn_mask is None:
139
139
  return flash_attn3(q, k, v, softmax_scale=scale)
140
140
  else:
141
- logger.warning(
142
- f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation"
143
- )
141
+ if not flash_attn3_compatible:
142
+ logger.warning(f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}, will use fallback attention implementation")
143
+ else:
144
+ logger.debug("flash_attn_3 does not support attention mask, will use fallback attention implementation")
144
145
  if XFORMERS_AVAILABLE:
145
146
  return xformers_attn(q, k, v, attn_mask=attn_mask, scale=scale)
146
147
  if SDPA_AVAILABLE:
@@ -156,6 +157,8 @@ def attention(
156
157
  raise RuntimeError(
157
158
  f"head_dim={q.shape[-1]}, but flash_attn_3 only supports head dimension at most {FA3_MAX_HEADDIM}"
158
159
  )
160
+ if attn_mask is not None:
161
+ raise RuntimeError("flash_attn_3 does not support attention mask")
159
162
  return flash_attn3(q, k, v, softmax_scale=scale)
160
163
  if attn_impl == "flash_attn_2":
161
164
  return flash_attn2(q, k, v, softmax_scale=scale)
@@ -0,0 +1,306 @@
1
+ from typing import Tuple, Dict
2
+
3
+ import math
4
+ import numpy as np
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+
11
+ from diffsynth_engine.models.base import PreTrainedModel
12
+ from diffsynth_engine.models.basic import attention as attention_ops
13
+
14
+
15
+ # ⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇ Wav2Vec2ForCTC ⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇⬇
16
class Wav2Vec2Config:
    """Fixed hyperparameters for the Wav2Vec2 modules defined in this file."""

    def __init__(self):
        # Convolutional feature encoder: one entry per conv layer (7 layers).
        self.conv_bias = True
        self.conv_dim = [512] * 7
        self.conv_kernel = [10] + [3] * 4 + [2] * 2
        self.conv_stride = [5] + [2] * 6
        # Transformer encoder sizes.
        self.hidden_size = 1024
        self.intermediate_size = 4096
        self.layer_norm_eps = 1e-05
        self.num_attention_heads = 16
        # Convolutional positional embedding.
        self.num_conv_pos_embedding_groups = 16
        self.num_conv_pos_embeddings = 128
        # Layer counts.
        self.num_feat_extract_layers = 7
        self.num_hidden_layers = 24
30
+
31
+
32
class Wav2Vec2LayerNormConvLayer(nn.Module):
    """One feature-encoder stage: Conv1d, LayerNorm over channels, then GELU."""

    def __init__(self, config: Wav2Vec2Config, layer_id=0, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        # Layer 0 reads a single input channel; later layers read the
        # previous layer's output channels.
        self.in_conv_dim = 1 if layer_id <= 0 else config.conv_dim[layer_id - 1]
        self.out_conv_dim = config.conv_dim[layer_id]

        factory_kwargs = {"device": device, "dtype": dtype}
        self.conv = nn.Conv1d(
            self.in_conv_dim,
            self.out_conv_dim,
            kernel_size=config.conv_kernel[layer_id],
            stride=config.conv_stride[layer_id],
            bias=config.conv_bias,
            **factory_kwargs,
        )
        self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True, **factory_kwargs)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        conv_out = self.conv(hidden_states)
        # LayerNorm acts on the last dim, so swap the last two dims around it.
        normed = self.layer_norm(conv_out.transpose(-2, -1)).transpose(-2, -1)
        return F.gelu(normed)
56
+
57
+
58
class Wav2Vec2SamePadLayer(nn.Module):
    """Drop the trailing element of the last dim when the kernel size is even."""

    def __init__(self, num_conv_pos_embeddings: int):
        super().__init__()
        # 1 for an even kernel size, 0 for an odd one.
        self.num_pad_remove = int(num_conv_pos_embeddings % 2 == 0)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.num_pad_remove <= 0:
            return hidden_states
        return hidden_states[:, :, : -self.num_pad_remove]
67
+
68
+
69
class Wav2Vec2PositionalConvEmbedding(nn.Module):
    """Positional embedding computed by a weight-normalized grouped Conv1d."""

    def __init__(self, config: Wav2Vec2Config, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        kernel = config.num_conv_pos_embeddings
        base_conv = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size=kernel,
            padding=kernel // 2,
            groups=config.num_conv_pos_embedding_groups,
            device=device,
            dtype=dtype,
        )
        # Weight normalization is applied to the conv's "weight" along dim 2.
        self.conv = nn.utils.parametrizations.weight_norm(base_conv, name="weight", dim=2)
        self.padding = Wav2Vec2SamePadLayer(kernel)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Conv1d works on dim 2, so swap dims 1 and 2 on the way in and out.
        channels_first = hidden_states.transpose(1, 2)
        embedded = F.gelu(self.padding(self.conv(channels_first)))
        return embedded.transpose(1, 2)
92
+
93
+
94
class Wav2Vec2FeedForward(nn.Module):
    """Two-layer MLP: hidden_size -> intermediate_size -> hidden_size with GELU in between."""

    def __init__(self, config: Wav2Vec2Config, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size, **factory_kwargs)
        self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size, **factory_kwargs)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded = F.gelu(self.intermediate_dense(hidden_states))
        return self.output_dense(expanded)
105
+
106
+
107
class Wav2Vec2Attention(nn.Module):
    """Multi-head self-attention with separate q/k/v/out linear projections."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.num_heads = num_heads

        def make_proj() -> nn.Linear:
            return nn.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)

        # Creation order (k, v, q, out) is kept so submodule registration order
        # and random-initialization order are unchanged.
        self.k_proj = make_proj()
        self.v_proj = make_proj()
        self.q_proj = make_proj()
        self.out_proj = make_proj()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Self-attend over a (batch, time, channel) tensor and return the same layout."""

        def split_heads(proj: nn.Linear) -> torch.Tensor:
            # Project, then split the channel dim into (heads, head_dim).
            return rearrange(proj(hidden_states), "b s (h d) -> b s h d", h=self.num_heads).contiguous()

        q = split_heads(self.q_proj)
        k = split_heads(self.k_proj)
        v = split_heads(self.v_proj)
        attended = attention_ops.attention(q=q, k=k, v=v)
        merged = rearrange(attended, "b s h d -> b s (h d)").contiguous()
        return self.out_proj(merged)
137
+
138
+
139
class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module):
    """Pre-norm transformer layer: LayerNorm -> attention and LayerNorm -> feed-forward, each with a residual add."""

    def __init__(self, config: Wav2Vec2Config, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.attention = Wav2Vec2Attention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            **factory_kwargs,
        )
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, **factory_kwargs)
        self.feed_forward = Wav2Vec2FeedForward(config, **factory_kwargs)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, **factory_kwargs)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Attention sub-block with residual connection.
        attn_out = self.attention(self.layer_norm(hidden_states))
        after_attn = hidden_states + attn_out
        # Feed-forward sub-block with residual connection.
        ffn_out = self.feed_forward(self.final_layer_norm(after_attn))
        return after_attn + ffn_out
159
+
160
+
161
class Wav2Vec2EncoderStableLayerNorm(nn.Module):
    """Transformer encoder stack that returns the hidden state seen at every depth."""

    def __init__(self, config: Wav2Vec2Config, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config, device=device, dtype=dtype)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype)
        blocks = []
        for _ in range(config.num_hidden_layers):
            blocks.append(Wav2Vec2EncoderLayerStableLayerNorm(config, device=device, dtype=dtype))
        self.layers = nn.ModuleList(blocks)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return a tuple with one entry per layer input plus the final normalized output.

        The tuple has ``len(self.layers) + 1`` tensors: the input to each
        layer in order, followed by the last layer's output after the final
        LayerNorm.
        """
        hidden_states = hidden_states + self.pos_conv_embed(hidden_states)
        collected = []
        for block in self.layers:
            collected.append(hidden_states)
            hidden_states = block(hidden_states)
        collected.append(self.layer_norm(hidden_states))
        return tuple(collected)
183
+
184
+
185
class Wav2Vec2FeatureEncoder(nn.Module):
    """Stack of ``config.num_feat_extract_layers`` conv layers applied in sequence."""

    def __init__(self, config: Wav2Vec2Config, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        stages = []
        for layer_idx in range(config.num_feat_extract_layers):
            stages.append(Wav2Vec2LayerNormConvLayer(config, layer_id=layer_idx, device=device, dtype=dtype))
        self.conv_layers = nn.ModuleList(stages)

    def forward(self, input_values: torch.Tensor) -> torch.Tensor:
        # Insert a size-1 dim at position 1 before the first conv layer.
        features = input_values.unsqueeze(1)
        for stage in self.conv_layers:
            features = stage(features)
        return features
200
+
201
+
202
class Wav2Vec2FeatureProjection(nn.Module):
    """LayerNorm over the last conv dimension followed by a linear map to hidden_size."""

    def __init__(self, config: Wav2Vec2Config, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        feature_dim = config.conv_dim[-1]
        self.layer_norm = nn.LayerNorm(feature_dim, eps=config.layer_norm_eps, device=device, dtype=dtype)
        self.projection = nn.Linear(feature_dim, config.hidden_size, device=device, dtype=dtype)

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(projected, normalized)``; the normalized, unprojected features are returned as well."""
        normalized = self.layer_norm(hidden_states)
        projected = self.projection(normalized)
        return projected, normalized
213
+
214
+
215
class Wav2Vec2StateDictConverter:
    """Keep only ``wav2vec2.``-prefixed weights and strip that prefix from their names."""

    def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Return a new dict with renamed keys.

        Keys that do not start with ``wav2vec2.`` or that contain
        ``masked_spec_embed`` are dropped.
        """
        prefix = "wav2vec2."
        return {
            name[len(prefix) :]: tensor
            for name, tensor in state_dict.items()
            if name.startswith(prefix) and "masked_spec_embed" not in name
        }
222
+
223
+
224
class Wav2Vec2Model(PreTrainedModel):
    """Wav2Vec2 backbone: conv feature extractor -> feature projection -> transformer encoder."""

    converter = Wav2Vec2StateDictConverter()
    _supports_parallelization = False

    def __init__(self, config: Wav2Vec2Config, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.feature_extractor = Wav2Vec2FeatureEncoder(config, **factory_kwargs)
        self.feature_projection = Wav2Vec2FeatureProjection(config, **factory_kwargs)
        self.encoder = Wav2Vec2EncoderStableLayerNorm(config, **factory_kwargs)

    def forward(self, input_values: torch.Tensor) -> torch.Tensor:
        """Run the full stack and return whatever the encoder returns (its collected hidden states)."""
        conv_features = self.feature_extractor(input_values)
        # Swap dims 1 and 2 before projecting.
        projected, _unused_norm = self.feature_projection(conv_features.transpose(1, 2))
        return self.encoder(projected)
238
+
239
+
240
+ # ⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆ Wav2Vec2ForCTC ⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆⬆
241
+
242
+
243
def get_sample_indices(original_fps: int, target_fps: int, total_frames: int, num_samples: int) -> np.ndarray:
    """Pick ``num_samples`` source-frame indices that resample a clip to ``target_fps``.

    Sample times are spaced ``1 / target_fps`` apart starting at 0; each is
    mapped to the nearest frame of the ``original_fps`` clip and clipped to
    ``[0, total_frames - 1]``.

    Raises:
        ValueError: if ``num_samples / target_fps`` exceeds the clip duration
            ``total_frames / original_fps``.
    """
    span_seconds = num_samples / target_fps
    clip_seconds = total_frames / original_fps
    if span_seconds > clip_seconds:
        raise ValueError("required_duration must be less than video length")

    timestamps = np.linspace(0, span_seconds, num_samples, endpoint=False)
    nearest_frames = np.round(timestamps * original_fps).astype(int)
    return np.clip(nearest_frames, 0, total_frames - 1)
252
+
253
+
254
def linear_interpolation(features: torch.Tensor, input_fps: int, output_fps: int) -> torch.Tensor:
    """Linearly resample a feature sequence from ``input_fps`` to ``output_fps``.

    Args:
        features: 3-D tensor laid out as [batch, T, channels]; dim 1 is resampled.
        input_fps: frame rate of ``features``.
        output_fps: frame rate of the result.

    Returns:
        Tensor of shape [batch, int(T / input_fps * output_fps), channels].
    """
    channels_first = features.transpose(1, 2)  # [batch, channels, T]
    duration = channels_first.shape[2] / float(input_fps)
    target_len = int(duration * output_fps)
    resampled = F.interpolate(channels_first, size=target_len, align_corners=True, mode="linear")
    return resampled.transpose(1, 2)  # [batch, target_len, channels]
268
+
269
+
270
def extract_audio_feat(audio_input: torch.Tensor, model: Wav2Vec2Model, dtype=torch.float32, device="cuda:0") -> torch.Tensor:
    """Encode a waveform batch and resample the encoder features to 30 fps.

    Args:
        audio_input: waveform tensor; normalized along dim 1
            (assumed [batch, samples] -- TODO confirm with callers).
        model: encoder whose call returns a sequence of tensors (one per
            hidden state); they are concatenated along dim 0.
        dtype: dtype of the returned features.
        device: device the normalized waveform is moved to before encoding.

    Returns:
        Features resampled from 50 fps (the rate this function assumes for
        the encoder output) to 30 fps, cast to ``dtype``.
    """
    video_rate = 30
    # Zero-mean / unit-variance normalization. The population variance
    # (unbiased=False) matches the reference Wav2Vec2 feature-extractor
    # normalization, which uses numpy's default variance, and avoids the NaN
    # that the unbiased estimator yields for a single-sample input.
    centered = audio_input - audio_input.mean(dim=1, keepdim=True)
    variance = audio_input.var(dim=1, unbiased=False, keepdim=True)
    input_values = centered / torch.sqrt(variance + 1e-7)
    feat = torch.cat(model(input_values.to(device)))
    feat = linear_interpolation(feat, input_fps=50, output_fps=video_rate)
    return feat.to(dtype)  # Encoding for the motion
276
+
277
+
278
def get_audio_embed_bucket_fps(
    audio_embed: torch.Tensor, num_frames_per_batch: int, fps: int = 16
) -> Tuple[torch.Tensor, int]:
    """Pick one audio embedding per video frame, zero-padded to whole batches.

    Args:
        audio_embed: tensor of shape [num_layers, num_audio_frames, audio_dim],
            treated as sampled at 30 frames per second.
        num_frames_per_batch: number of video frames in one batch.
        fps: video frame rate the embeddings are resampled to.

    Returns:
        ``(batch_audio_embed, max_num_batches)`` where ``batch_audio_embed``
        has shape [max_num_batches * num_frames_per_batch, num_layers,
        audio_dim]. Frames that fall past the end of the audio are zeros with
        the same dtype and device as ``audio_embed``.
    """
    video_rate = 30
    scale = video_rate / fps
    num_layers, num_audio_frames, audio_dim = audio_embed.shape
    # Always reserve one batch beyond the last fully covered one, so the tail
    # of the audio lands in a (partially padded) batch.
    max_num_batches = int(num_audio_frames / (num_frames_per_batch * scale)) + 1
    num_buckets = max_num_batches * num_frames_per_batch
    num_audio_padding = math.ceil(num_buckets / fps * video_rate) - num_audio_frames
    batch_indices = get_sample_indices(
        original_fps=video_rate,
        target_fps=fps,
        total_frames=num_audio_frames + num_audio_padding,
        num_samples=num_buckets,
    )

    # new_zeros keeps the padding's dtype/device equal to audio_embed's; the
    # previous float32 default mismatched half-precision embeddings at stack time.
    padding_embed = audio_embed.new_zeros((num_layers, audio_dim))
    batch_audio_embed = []
    for batch_idx in batch_indices:
        if batch_idx < num_audio_frames:
            # get_sample_indices clips to >= 0, so this is a valid in-range
            # index. Direct indexing also avoids the zero-step range() that
            # the old stride computation hit for fps > 30.
            frame_audio_embed = audio_embed[:, int(batch_idx)]
        else:
            frame_audio_embed = padding_embed
        batch_audio_embed.append(frame_audio_embed)
    batch_audio_embed = torch.stack(batch_audio_embed, dim=0)

    return batch_audio_embed, max_num_batches
@@ -6,6 +6,7 @@ from typing import Any, Dict, Tuple, Optional
6
6
  from einops import rearrange
7
7
 
8
8
  from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
9
+ from diffsynth_engine.models.basic.attention import attention
9
10
  from diffsynth_engine.models.basic import attention as attention_ops
10
11
  from diffsynth_engine.models.basic.transformer_helper import RMSNorm
11
12
  from diffsynth_engine.utils.constants import (
@@ -18,6 +19,7 @@ from diffsynth_engine.utils.constants import (
18
19
  WAN2_2_DIT_T2V_A14B_CONFIG_FILE,
19
20
  )
20
21
  from diffsynth_engine.utils.gguf import gguf_inference
22
+ from diffsynth_engine.utils.fp8_linear import fp8_inference
21
23
  from diffsynth_engine.utils.parallel import (
22
24
  cfg_parallel,
23
25
  cfg_parallel_unshard,
@@ -142,12 +144,12 @@ class CrossAttention(nn.Module):
142
144
  k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
143
145
  v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
144
146
 
145
- x = attention_ops.attention(q, k, v, **self.attn_kwargs).flatten(2)
147
+ x = attention(q, k, v, **self.attn_kwargs).flatten(2)
146
148
  if self.has_image_input:
147
149
  k_img, v_img = self.norm_k_img(self.k_img(img)), self.v_img(img)
148
150
  k_img = rearrange(k_img, "b s (n d) -> b s n d", n=num_heads)
149
151
  v_img = rearrange(v_img, "b s (n d) -> b s n d", n=num_heads)
150
- y = attention_ops.attention(q, k_img, v_img, **self.attn_kwargs).flatten(2)
152
+ y = attention(q, k_img, v_img, **self.attn_kwargs).flatten(2)
151
153
  x = x + y
152
154
  return self.o(x)
153
155
 
@@ -343,8 +345,10 @@ class WanDiT(PreTrainedModel):
343
345
  clip_feature: Optional[torch.Tensor] = None, # clip_vision_encoder(img)
344
346
  y: Optional[torch.Tensor] = None, # vae_encoder(img)
345
347
  ):
348
+ fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
346
349
  use_cfg = x.shape[0] > 1
347
350
  with (
351
+ fp8_inference(fp8_linear_enabled),
348
352
  gguf_inference(),
349
353
  cfg_parallel((x, context, timestep, clip_feature, y), use_cfg=use_cfg),
350
354
  ):